diff --git a/CHANGELOG.md b/CHANGELOG.md index f0d77bd..c63318f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,14 @@ # UNRELEASED - Compatible with ENet 1.3.18 - Refine some trait requirements and derives +- Reduce allocations introduced by Rust port ([#1](https://github.com/jabuwu/rusty_enet/issues/1)) +- Adjust `Socket::receive` interface to one which takes a pre-allocated buffer +- Add `MTU_MAX` constant (an alias of `ENET_PROTOCOL_MAXIMUM_MTU`) +- Add functions: + - [`Host::mtu`] + - [`Host::set_mtu`] + - [`Peer::mtu`] + - [`Peer::set_mtu`] # 0.1.0 - Initial release diff --git a/src/c/host.rs b/src/c/host.rs index cd427d5..a7d547e 100644 --- a/src/c/host.rs +++ b/src/c/host.rs @@ -27,14 +27,14 @@ pub(crate) struct ENetHost { pub(crate) total_queued: u32, pub(crate) packet_size: usize, pub(crate) header_flags: u16, - pub(crate) commands: [ENetProtocol; 32], + pub(crate) commands: [ENetProtocol; ENET_PROTOCOL_MAXIMUM_PACKET_COMMANDS as usize], pub(crate) command_count: usize, - pub(crate) buffers: [ENetBuffer; 65], + pub(crate) buffers: [ENetBuffer; ENET_BUFFER_MAXIMUM as usize], pub(crate) buffer_count: usize, pub(crate) checksum: MaybeUninit u32>>>, pub(crate) time: MaybeUninit Duration>>, pub(crate) compressor: MaybeUninit>>, - pub(crate) packet_data: [[u8; 4096]; 2], + pub(crate) packet_data: [[u8; ENET_PROTOCOL_MAXIMUM_MTU]; 2], pub(crate) received_address: MaybeUninit>, pub(crate) received_data: *mut u8, pub(crate) received_data_length: usize, diff --git a/src/c/protocol.rs b/src/c/protocol.rs index c79d22f..c4f61a5 100644 --- a/src/c/protocol.rs +++ b/src/c/protocol.rs @@ -2,16 +2,16 @@ use std::ptr::{copy_nonoverlapping, write_bytes}; use crate::{ consts::{ - ENET_HOST_BANDWIDTH_THROTTLE_INTERVAL, ENET_PEER_FREE_RELIABLE_WINDOWS, - ENET_PEER_FREE_UNSEQUENCED_WINDOWS, ENET_PEER_PACKET_LOSS_INTERVAL, - ENET_PEER_PACKET_LOSS_SCALE, ENET_PEER_PACKET_THROTTLE_COUNTER, - ENET_PEER_PACKET_THROTTLE_SCALE, ENET_PEER_RELIABLE_WINDOWS, - ENET_PEER_RELIABLE_WINDOW_SIZE, ENET_PEER_UNSEQUENCED_WINDOW_SIZE, - ENET_PEER_WINDOW_SIZE_SCALE, ENET_PROTOCOL_MAXIMUM_CHANNEL_COUNT, - ENET_PROTOCOL_MAXIMUM_FRAGMENT_COUNT, ENET_PROTOCOL_MAXIMUM_MTU, - ENET_PROTOCOL_MAXIMUM_PEER_ID, ENET_PROTOCOL_MAXIMUM_WINDOW_SIZE, - ENET_PROTOCOL_MINIMUM_CHANNEL_COUNT, ENET_PROTOCOL_MINIMUM_MTU, - ENET_PROTOCOL_MINIMUM_WINDOW_SIZE, + ENET_BUFFER_MAXIMUM, ENET_HOST_BANDWIDTH_THROTTLE_INTERVAL, + ENET_PEER_FREE_RELIABLE_WINDOWS, ENET_PEER_FREE_UNSEQUENCED_WINDOWS, + ENET_PEER_PACKET_LOSS_INTERVAL, ENET_PEER_PACKET_LOSS_SCALE, + ENET_PEER_PACKET_THROTTLE_COUNTER, ENET_PEER_PACKET_THROTTLE_SCALE, + ENET_PEER_RELIABLE_WINDOWS, ENET_PEER_RELIABLE_WINDOW_SIZE, + ENET_PEER_UNSEQUENCED_WINDOW_SIZE, ENET_PEER_WINDOW_SIZE_SCALE, + ENET_PROTOCOL_MAXIMUM_CHANNEL_COUNT, ENET_PROTOCOL_MAXIMUM_FRAGMENT_COUNT, + ENET_PROTOCOL_MAXIMUM_MTU, ENET_PROTOCOL_MAXIMUM_PEER_ID, + ENET_PROTOCOL_MAXIMUM_WINDOW_SIZE, ENET_PROTOCOL_MINIMUM_CHANNEL_COUNT, + ENET_PROTOCOL_MINIMUM_MTU, ENET_PROTOCOL_MINIMUM_WINDOW_SIZE, }, enet_free, enet_host_bandwidth_throttle, enet_list_clear, enet_list_insert, enet_list_remove, enet_malloc, enet_packet_destroy, enet_peer_disconnect, @@ -19,12 +19,13 @@ use crate::{ enet_peer_has_outgoing_commands, enet_peer_on_connect, enet_peer_on_disconnect, enet_peer_ping, enet_peer_queue_acknowledgement, enet_peer_queue_incoming_command, enet_peer_queue_outgoing_command, enet_peer_receive, enet_peer_reset, enet_peer_reset_queues, - enet_peer_throttle, enet_time_get, Address, ENetAcknowledgement, ENetBuffer, ENetChannel, - ENetEvent, ENetHost, ENetIncomingCommand, ENetList, ENetListIterator, ENetListNode, - ENetOutgoingCommand, ENetPeer, ENetPeerState, PacketReceived, Socket, ENET_EVENT_TYPE_CONNECT, - ENET_EVENT_TYPE_DISCONNECT, ENET_EVENT_TYPE_NONE, ENET_EVENT_TYPE_RECEIVE, - ENET_PACKET_FLAG_RELIABLE, ENET_PACKET_FLAG_SENT, ENET_PACKET_FLAG_UNRELIABLE_FRAGMENT, - ENET_PACKET_FLAG_UNSEQUENCED, ENET_PEER_FLAG_CONTINUE_SENDING, ENET_PEER_FLAG_NEEDS_DISPATCH, + enet_peer_throttle, enet_time_get, from_raw_parts_or_empty, Address, ENetAcknowledgement, + ENetBuffer, ENetChannel, ENetEvent, ENetHost, ENetIncomingCommand, ENetList, ENetListIterator, + ENetListNode, ENetOutgoingCommand, ENetPeer, ENetPeerState, PacketReceived, Socket, + ENET_EVENT_TYPE_CONNECT, ENET_EVENT_TYPE_DISCONNECT, ENET_EVENT_TYPE_NONE, + ENET_EVENT_TYPE_RECEIVE, ENET_PACKET_FLAG_RELIABLE, ENET_PACKET_FLAG_SENT, + ENET_PACKET_FLAG_UNRELIABLE_FRAGMENT, ENET_PACKET_FLAG_UNSEQUENCED, + ENET_PEER_FLAG_CONTINUE_SENDING, ENET_PEER_FLAG_NEEDS_DISPATCH, ENET_PEER_STATE_ACKNOWLEDGING_CONNECT, ENET_PEER_STATE_ACKNOWLEDGING_DISCONNECT, ENET_PEER_STATE_CONNECTED, ENET_PEER_STATE_CONNECTING, ENET_PEER_STATE_CONNECTION_PENDING, ENET_PEER_STATE_CONNECTION_SUCCEEDED, ENET_PEER_STATE_DISCONNECTED, @@ -1657,16 +1658,15 @@ unsafe fn enet_protocol_receive_incoming_commands( data_length: 0, }; buffer.data = ((*host).packet_data[0_i32 as usize]).as_mut_ptr(); - buffer.data_length = ::core::mem::size_of::<[u8; ENET_PROTOCOL_MAXIMUM_MTU as usize]>(); - let received_length = match (*host).socket.assume_init_mut().receive(buffer.data_length) { - Ok(Some((received_address, PacketReceived::Complete(received_data)))) => { - if received_data.len() <= ENET_PROTOCOL_MAXIMUM_MTU as usize { - *(*host).received_address.assume_init_mut() = Some(received_address); - copy_nonoverlapping(received_data.as_ptr(), buffer.data, received_data.len()); - received_data.len() as i32 - } else { - continue; - } + buffer.data_length = ::core::mem::size_of::<[u8; ENET_PROTOCOL_MAXIMUM_MTU]>(); + let received_length = match (*host) + .socket + .assume_init_mut() + .receive(&mut *buffer.data.cast::<[u8; 4096]>()) + { + Ok(Some((received_address, PacketReceived::Complete(received_length)))) => { + *(*host).received_address.assume_init_mut() = Some(received_address); + received_length } Ok(Some((_, PacketReceived::Partial))) => { continue; @@ -2210,16 +2210,20 @@ unsafe fn enet_protocol_send_outgoing_commands( if let Some(compressor) = (*host).compressor.assume_init_mut() { let original_size: usize = ((*host).packet_size) .wrapping_sub(::core::mem::size_of::()); - let mut in_buffers = vec![]; + let mut in_buffers: [&[u8]; ENET_BUFFER_MAXIMUM as usize] = + std::array::from_fn(|_| { + from_raw_parts_or_empty::(std::ptr::null(), 0) + }); + #[allow(clippy::needless_range_loop)] for i in 0..((*host).buffer_count).wrapping_sub(1) { let buffer = ((*host).buffers).as_mut_ptr().add(1 + i); - in_buffers.push(super::from_raw_parts_or_empty( + in_buffers[i] = super::from_raw_parts_or_empty( (*buffer).data, (*buffer).data_length, - )); + ); } let compressed_size: usize = compressor.compress( - in_buffers, + &in_buffers[0..((*host).buffer_count).wrapping_sub(1)], original_size, super::from_raw_parts_or_empty_mut( ((*host).packet_data[1_i32 as usize]).as_mut_ptr(), @@ -2265,13 +2269,17 @@ unsafe fn enet_protocol_send_outgoing_commands( *fresh35 = (*fresh35 as u64) .wrapping_add(::core::mem::size_of::() as u64) as usize; - let mut in_buffers = vec![]; + let mut in_buffers: [&[u8]; ENET_BUFFER_MAXIMUM as usize] = + std::array::from_fn(|_| { + from_raw_parts_or_empty::(std::ptr::null(), 0) + }); + #[allow(clippy::needless_range_loop)] for i in 0..(*host).buffer_count { let buffer = ((*host).buffers).as_mut_ptr().add(i); - in_buffers.push(super::from_raw_parts_or_empty( + in_buffers[i] = super::from_raw_parts_or_empty( (*buffer).data, (*buffer).data_length, - )); + ); } checksum = checksum_fn(&in_buffers); copy_nonoverlapping( @@ -2288,6 +2296,7 @@ unsafe fn enet_protocol_send_outgoing_commands( } (*current_peer).last_send_time = (*host).service_time; let mut conglomerate_buffer = vec![]; + conglomerate_buffer.reserve_exact(ENET_BUFFER_MAXIMUM as usize); for buffer_index in 0..(*host).buffer_count { let buffer = &(*host).buffers[buffer_index]; conglomerate_buffer.extend_from_slice(super::from_raw_parts_or_empty( diff --git a/src/compressor.rs b/src/compressor.rs index 907668c..b028a9c 100644 --- a/src/compressor.rs +++ b/src/compressor.rs @@ -1,12 +1,14 @@ +use std::mem::zeroed; + use crate::{ - enet_range_coder_compress, enet_range_coder_create, enet_range_coder_decompress, - enet_range_coder_destroy, ENetBuffer, ENetRangeCoder, + consts::ENET_BUFFER_MAXIMUM, enet_range_coder_compress, enet_range_coder_create, + enet_range_coder_decompress, enet_range_coder_destroy, ENetBuffer, ENetRangeCoder, }; /// An interface for compressing ENet packets. pub trait Compressor { /// Compress the incoming buffers. - fn compress(&mut self, in_buffers: Vec<&[u8]>, in_limit: usize, out: &mut [u8]) -> usize; + fn compress(&mut self, in_buffers: &[&[u8]], in_limit: usize, out: &mut [u8]) -> usize; /// Decompress the buffer. fn decompress(&mut self, in_data: &[u8], out: &mut [u8]) -> usize; } @@ -32,19 +34,20 @@ impl Default for RangeCoder { } impl Compressor for RangeCoder { - fn compress(&mut self, in_buffers: Vec<&[u8]>, in_limit: usize, out: &mut [u8]) -> usize { + fn compress(&mut self, in_buffers: &[&[u8]], in_limit: usize, out: &mut [u8]) -> usize { unsafe { - let mut buffers = vec![]; - for in_buffer in in_buffers { - buffers.push(ENetBuffer { + let mut buffers: [ENetBuffer; ENET_BUFFER_MAXIMUM as usize] = + std::array::from_fn(|_| zeroed()); + for (i, in_buffer) in in_buffers.iter().enumerate() { + buffers[i] = ENetBuffer { data: in_buffer.as_ptr().cast_mut(), data_length: in_buffer.len(), - }); + }; } enet_range_coder_compress( self.0.cast(), buffers.as_ptr(), - buffers.len(), + in_buffers.len(), in_limit, out.as_mut_ptr(), out.len(), diff --git a/src/consts.rs b/src/consts.rs index 254fd0c..1930ee1 100644 --- a/src/consts.rs +++ b/src/consts.rs @@ -5,8 +5,8 @@ pub const ENET_PROTOCOL_MINIMUM_CHANNEL_COUNT: u32 = 1; pub const ENET_PROTOCOL_MAXIMUM_WINDOW_SIZE: u32 = 65536; pub const ENET_PROTOCOL_MINIMUM_WINDOW_SIZE: u32 = 4096; pub const ENET_PROTOCOL_MAXIMUM_PACKET_COMMANDS: u32 = 32; -pub const ENET_PROTOCOL_MAXIMUM_MTU: u32 = 4096; -pub const ENET_PROTOCOL_MINIMUM_MTU: u32 = 576; +pub const ENET_PROTOCOL_MAXIMUM_MTU: usize = 4096; +pub const ENET_PROTOCOL_MINIMUM_MTU: usize = 576; pub const ENET_PEER_FREE_RELIABLE_WINDOWS: u32 = 8; pub const ENET_PEER_RELIABLE_WINDOW_SIZE: u32 = 4096; pub const ENET_PEER_RELIABLE_WINDOWS: u32 = 16; @@ -33,3 +33,5 @@ pub const ENET_HOST_DEFAULT_MTU: u32 = 1392; pub const ENET_HOST_BANDWIDTH_THROTTLE_INTERVAL: u32 = 1000; pub const ENET_HOST_SEND_BUFFER_SIZE: u32 = 262144; pub const ENET_HOST_RECEIVE_BUFFER_SIZE: u32 = 262144; + +pub const ENET_BUFFER_MAXIMUM: u32 = ENET_PROTOCOL_MAXIMUM_PACKET_COMMANDS * 2 + 1; diff --git a/src/host.rs b/src/host.rs index 2baa0ce..4435eed 100644 --- a/src/host.rs +++ b/src/host.rs @@ -107,6 +107,7 @@ impl Host { ) .map_err(|err| HostNewError::FailedToInitializeSocket(err))?; let mut peers = vec![]; + peers.reserve_exact((*host).peer_count); for peer_index in 0..(*host).peer_count { peers.push(Peer((*host).peers.add(peer_index))); } @@ -375,6 +376,19 @@ impl Host { Ok(()) } + /// The maximum transmission unit, or the maximum packet size that will be sent by this host. + #[must_use] + pub fn mtu(&self) -> u16 { + unsafe { (*self.host).mtu as u16 } + } + + /// Set the maximum transmission unit. See [`Host::mtu`]. + pub fn set_mtu(&self, mtu: u16) { + unsafe { + (*self.host).mtu = mtu as u32; + } + } + fn create_event<'a>(&'a mut self, event: &ENetEvent) -> Event<'a, S> { match event.type_0 { ENET_EVENT_TYPE_CONNECT => Event::Connect { diff --git a/src/peer.rs b/src/peer.rs index eaedaf9..6e02c40 100644 --- a/src/peer.rs +++ b/src/peer.rs @@ -174,6 +174,20 @@ impl Peer { unsafe { enet_peer_throttle_configure(self.0, interval, acceleration, deceleration) } } + /// The maximum transmission unit of this peer. See [`Host::mtu`](`crate::Host::mtu`). + #[must_use] + pub fn mtu(&self) -> u16 { + unsafe { (*self.0).mtu as u16 } + } + + /// Set the maximum transmission unit for this peer. See + /// [`Host::set_mtu`](`crate::Host::set_mtu`). + pub fn set_mtu(&self, mtu: u16) { + unsafe { + (*self.0).mtu = mtu as u32; + } + } + /// Get the current state of the peer. #[must_use] pub fn state(&self) -> PeerState { diff --git a/src/read_write.rs b/src/read_write.rs index 0e22f12..84ca09a 100644 --- a/src/read_write.rs +++ b/src/read_write.rs @@ -1,6 +1,9 @@ -use std::collections::VecDeque; +use std::{ + collections::VecDeque, + io::{copy, Cursor}, +}; -use crate::{Address, PacketReceived, Socket, SocketOptions}; +use crate::{Address, PacketReceived, Socket, SocketOptions, MTU_MAX}; /// Provides a Read/Write interface for use with [`Host`](`crate::Host`). /// @@ -81,11 +84,18 @@ impl Socket Ok(buffer.len()) } - fn receive(&mut self, _mtu: usize) -> Result, E> { + fn receive(&mut self, buffer: &mut [u8; MTU_MAX]) -> Result, E> { if let Some(error) = self.error.take() { Err(error) - } else if let Some((address, buffer)) = self.inbound.pop_front() { - Ok(Some((address, PacketReceived::Complete(buffer)))) + } else if let Some((address, inbound)) = self.inbound.pop_front() { + let bytes = inbound.len(); + if bytes <= MTU_MAX { + copy(&mut Cursor::new(inbound), &mut Cursor::new(&mut buffer[..])) + .expect("Buffer copy should not fail."); + Ok(Some((address, PacketReceived::Complete(bytes)))) + } else { + Ok(None) + } } else { Ok(None) } diff --git a/src/socket.rs b/src/socket.rs index e930aa2..10d5788 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -3,7 +3,17 @@ use std::{ net::{SocketAddr, UdpSocket}, }; -use crate::Address; +use crate::{consts::ENET_PROTOCOL_MAXIMUM_MTU, Address}; + +/// The maximum amount of bytes ENet will ever send or receive. Useful for allocating buffers when +/// sending and receiving. +/// +/// The actual MTU used by hosts and peers is typically much lower than this maximum and can be +/// changed with [`Host::set_mtu`](`crate::Host::set_mtu`) and +/// [`Peer::set_mtu`](`crate::Peer::set_mtu`). +/// +/// A shorter an easier to remember equivalent to [`ENET_PROTOCOL_MAXIMUM_MTU`]. +pub const MTU_MAX: usize = ENET_PROTOCOL_MAXIMUM_MTU; /// Socket options provided by ENet and passed to [`Socket::init`] when creating a /// [`Host`](`crate::Host`). @@ -33,18 +43,26 @@ pub trait Socket: Sized { /// Initialize the socket with options passed down by ENet. /// - /// Called in [`Host::new`](`crate::Host::new`). + /// Called in [`Host::new`](`crate::Host::new`). If this function returns an error, it is + /// bubbled up through [`Host::new`](`crate::Host::new`). fn init(&mut self, _socket_options: SocketOptions) -> Result<(), Self::Error> { Ok(()) } /// Try to send data. Should return the number of bytes successfully sent, or an error. fn send(&mut self, address: Self::Address, buffer: &[u8]) -> Result; - /// Try to receive data. May return an error, or optionally, a data packet. + + /// Try to receive data from the socket into a buffer of size [`MTU_MAX`]. /// - /// Data packets are wrapped in [`PacketReceived`]. See its docs for more info. + /// A received packet should be written into the provided buffer. If a packet is received that + /// is larger than [`MTU_MAX`], it should simply be discarded. ENet will never send a packet + /// that is larger than this maximum, so if one is received, it was not sent by ENet. + /// + /// The return value should be `Ok(None)` if no packet was received. If a packet was received, + /// the address of the peer socket, as well as the amount of bytes received should be returned. + /// Packets received may be complete or partial. See [`PacketReceived`] for more info. fn receive( &mut self, - mtu: usize, + buffer: &mut [u8; MTU_MAX], ) -> Result, Self::Error>; } @@ -60,8 +78,8 @@ pub trait Socket: Sized { /// [`PacketReceived::Complete`]. #[derive(Debug)] pub enum PacketReceived { - /// A complete packet was received. - Complete(Vec), + /// A complete packet was received. The inner value is the size of the packet in bytes. + Complete(usize), /// A partial packet was received. Partial, } @@ -85,15 +103,14 @@ impl Socket for UdpSocket { } } - fn receive(&mut self, mtu: usize) -> Result, io::Error> { - let mut buffer = vec![0; mtu]; - match self.recv_from(&mut buffer) { + fn receive( + &mut self, + buffer: &mut [u8; MTU_MAX], + ) -> Result, io::Error> { + match self.recv_from(buffer) { Ok((recv_length, recv_addr)) => { // TODO: MSG_TRUNC? (not supported by rust stdlib) - Ok(Some(( - recv_addr, - PacketReceived::Complete(Vec::from(&buffer[0..recv_length])), - ))) + Ok(Some((recv_addr, PacketReceived::Complete(recv_length)))) } Err(err) if err.kind() == ErrorKind::WouldBlock => Ok(None), Err(err) => Err(err), diff --git a/src/test/network.rs b/src/test/network.rs index 04faa7a..bf9b5d3 100644 --- a/src/test/network.rs +++ b/src/test/network.rs @@ -1,6 +1,7 @@ use std::{ collections::HashMap, convert::Infallible, + io::{copy, Cursor}, ops::{Deref, DerefMut}, sync::{mpsc, Arc, RwLock}, time::Duration, @@ -9,7 +10,7 @@ use std::{ use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha20Rng; -use crate as enet; +use crate::{self as enet, MTU_MAX}; pub struct Socket { sender: mpsc::Sender<(usize, Vec)>, @@ -60,10 +61,17 @@ impl enet::Socket for Socket { fn receive( &mut self, - _mtu: usize, + buffer: &mut [u8; MTU_MAX], ) -> Result, Self::Error> { if let Some((address, data)) = Socket::receive(self) { - Ok(Some((address, enet::PacketReceived::Complete(data)))) + let data_length = data.len(); + if data_length <= MTU_MAX { + copy(&mut Cursor::new(data), &mut Cursor::new(&mut buffer[..])) + .expect("Buffer copy should not fail."); + Ok(Some((address, enet::PacketReceived::Complete(data_length)))) + } else { + Ok(None) + } } else { Ok(None) }