Skip to content

Commit

Permalink
reduce allocations and allow setting MTU
Browse files Browse the repository at this point in the history
  • Loading branch information
jabuwu committed May 18, 2024
1 parent 83f303f commit 033fb1f
Show file tree
Hide file tree
Showing 10 changed files with 154 additions and 69 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
# UNRELEASED
- Compatible with ENet 1.3.18
- Refine some trait requirements and derives
- Reduce allocations introduced by Rust port ([#1](https://github.com/jabuwu/rusty_enet/issues/1))
- Adjust `Socket::receive` interface to one which takes a pre-allocated buffer
- Add `MTU_MAX` constant (an alias of `ENET_PROTOCOL_MAXIMUM_MTU`)
- Add functions:
- [`Host::mtu`]
- [`Host::set_mtu`]
- [`Peer::mtu`]
- [`Peer::set_mtu`]

# 0.1.0
- Initial release
6 changes: 3 additions & 3 deletions src/c/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ pub(crate) struct ENetHost<S: Socket> {
pub(crate) total_queued: u32,
pub(crate) packet_size: usize,
pub(crate) header_flags: u16,
pub(crate) commands: [ENetProtocol; 32],
pub(crate) commands: [ENetProtocol; ENET_PROTOCOL_MAXIMUM_PACKET_COMMANDS as usize],
pub(crate) command_count: usize,
pub(crate) buffers: [ENetBuffer; 65],
pub(crate) buffers: [ENetBuffer; ENET_BUFFER_MAXIMUM as usize],
pub(crate) buffer_count: usize,
pub(crate) checksum: MaybeUninit<Option<Box<dyn Fn(&[&[u8]]) -> u32>>>,
pub(crate) time: MaybeUninit<Box<dyn Fn() -> Duration>>,
pub(crate) compressor: MaybeUninit<Option<Box<dyn Compressor>>>,
pub(crate) packet_data: [[u8; 4096]; 2],
pub(crate) packet_data: [[u8; ENET_PROTOCOL_MAXIMUM_MTU]; 2],
pub(crate) received_address: MaybeUninit<Option<S::Address>>,
pub(crate) received_data: *mut u8,
pub(crate) received_data_length: usize,
Expand Down
75 changes: 42 additions & 33 deletions src/c/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,30 @@ use std::ptr::{copy_nonoverlapping, write_bytes};

use crate::{
consts::{
ENET_HOST_BANDWIDTH_THROTTLE_INTERVAL, ENET_PEER_FREE_RELIABLE_WINDOWS,
ENET_PEER_FREE_UNSEQUENCED_WINDOWS, ENET_PEER_PACKET_LOSS_INTERVAL,
ENET_PEER_PACKET_LOSS_SCALE, ENET_PEER_PACKET_THROTTLE_COUNTER,
ENET_PEER_PACKET_THROTTLE_SCALE, ENET_PEER_RELIABLE_WINDOWS,
ENET_PEER_RELIABLE_WINDOW_SIZE, ENET_PEER_UNSEQUENCED_WINDOW_SIZE,
ENET_PEER_WINDOW_SIZE_SCALE, ENET_PROTOCOL_MAXIMUM_CHANNEL_COUNT,
ENET_PROTOCOL_MAXIMUM_FRAGMENT_COUNT, ENET_PROTOCOL_MAXIMUM_MTU,
ENET_PROTOCOL_MAXIMUM_PEER_ID, ENET_PROTOCOL_MAXIMUM_WINDOW_SIZE,
ENET_PROTOCOL_MINIMUM_CHANNEL_COUNT, ENET_PROTOCOL_MINIMUM_MTU,
ENET_PROTOCOL_MINIMUM_WINDOW_SIZE,
ENET_BUFFER_MAXIMUM, ENET_HOST_BANDWIDTH_THROTTLE_INTERVAL,
ENET_PEER_FREE_RELIABLE_WINDOWS, ENET_PEER_FREE_UNSEQUENCED_WINDOWS,
ENET_PEER_PACKET_LOSS_INTERVAL, ENET_PEER_PACKET_LOSS_SCALE,
ENET_PEER_PACKET_THROTTLE_COUNTER, ENET_PEER_PACKET_THROTTLE_SCALE,
ENET_PEER_RELIABLE_WINDOWS, ENET_PEER_RELIABLE_WINDOW_SIZE,
ENET_PEER_UNSEQUENCED_WINDOW_SIZE, ENET_PEER_WINDOW_SIZE_SCALE,
ENET_PROTOCOL_MAXIMUM_CHANNEL_COUNT, ENET_PROTOCOL_MAXIMUM_FRAGMENT_COUNT,
ENET_PROTOCOL_MAXIMUM_MTU, ENET_PROTOCOL_MAXIMUM_PEER_ID,
ENET_PROTOCOL_MAXIMUM_WINDOW_SIZE, ENET_PROTOCOL_MINIMUM_CHANNEL_COUNT,
ENET_PROTOCOL_MINIMUM_MTU, ENET_PROTOCOL_MINIMUM_WINDOW_SIZE,
},
enet_free, enet_host_bandwidth_throttle, enet_list_clear, enet_list_insert, enet_list_remove,
enet_malloc, enet_packet_destroy, enet_peer_disconnect,
enet_peer_dispatch_incoming_reliable_commands, enet_peer_dispatch_incoming_unreliable_commands,
enet_peer_has_outgoing_commands, enet_peer_on_connect, enet_peer_on_disconnect, enet_peer_ping,
enet_peer_queue_acknowledgement, enet_peer_queue_incoming_command,
enet_peer_queue_outgoing_command, enet_peer_receive, enet_peer_reset, enet_peer_reset_queues,
enet_peer_throttle, enet_time_get, Address, ENetAcknowledgement, ENetBuffer, ENetChannel,
ENetEvent, ENetHost, ENetIncomingCommand, ENetList, ENetListIterator, ENetListNode,
ENetOutgoingCommand, ENetPeer, ENetPeerState, PacketReceived, Socket, ENET_EVENT_TYPE_CONNECT,
ENET_EVENT_TYPE_DISCONNECT, ENET_EVENT_TYPE_NONE, ENET_EVENT_TYPE_RECEIVE,
ENET_PACKET_FLAG_RELIABLE, ENET_PACKET_FLAG_SENT, ENET_PACKET_FLAG_UNRELIABLE_FRAGMENT,
ENET_PACKET_FLAG_UNSEQUENCED, ENET_PEER_FLAG_CONTINUE_SENDING, ENET_PEER_FLAG_NEEDS_DISPATCH,
enet_peer_throttle, enet_time_get, from_raw_parts_or_empty, Address, ENetAcknowledgement,
ENetBuffer, ENetChannel, ENetEvent, ENetHost, ENetIncomingCommand, ENetList, ENetListIterator,
ENetListNode, ENetOutgoingCommand, ENetPeer, ENetPeerState, PacketReceived, Socket,
ENET_EVENT_TYPE_CONNECT, ENET_EVENT_TYPE_DISCONNECT, ENET_EVENT_TYPE_NONE,
ENET_EVENT_TYPE_RECEIVE, ENET_PACKET_FLAG_RELIABLE, ENET_PACKET_FLAG_SENT,
ENET_PACKET_FLAG_UNRELIABLE_FRAGMENT, ENET_PACKET_FLAG_UNSEQUENCED,
ENET_PEER_FLAG_CONTINUE_SENDING, ENET_PEER_FLAG_NEEDS_DISPATCH,
ENET_PEER_STATE_ACKNOWLEDGING_CONNECT, ENET_PEER_STATE_ACKNOWLEDGING_DISCONNECT,
ENET_PEER_STATE_CONNECTED, ENET_PEER_STATE_CONNECTING, ENET_PEER_STATE_CONNECTION_PENDING,
ENET_PEER_STATE_CONNECTION_SUCCEEDED, ENET_PEER_STATE_DISCONNECTED,
Expand Down Expand Up @@ -1657,16 +1658,15 @@ unsafe fn enet_protocol_receive_incoming_commands<S: Socket>(
data_length: 0,
};
buffer.data = ((*host).packet_data[0_i32 as usize]).as_mut_ptr();
buffer.data_length = ::core::mem::size_of::<[u8; ENET_PROTOCOL_MAXIMUM_MTU as usize]>();
let received_length = match (*host).socket.assume_init_mut().receive(buffer.data_length) {
Ok(Some((received_address, PacketReceived::Complete(received_data)))) => {
if received_data.len() <= ENET_PROTOCOL_MAXIMUM_MTU as usize {
*(*host).received_address.assume_init_mut() = Some(received_address);
copy_nonoverlapping(received_data.as_ptr(), buffer.data, received_data.len());
received_data.len() as i32
} else {
continue;
}
buffer.data_length = ::core::mem::size_of::<[u8; ENET_PROTOCOL_MAXIMUM_MTU]>();
let received_length = match (*host)
.socket
.assume_init_mut()
.receive(&mut *buffer.data.cast::<[u8; 4096]>())
{
Ok(Some((received_address, PacketReceived::Complete(received_length)))) => {
*(*host).received_address.assume_init_mut() = Some(received_address);
received_length
}
Ok(Some((_, PacketReceived::Partial))) => {
continue;
Expand Down Expand Up @@ -2210,16 +2210,20 @@ unsafe fn enet_protocol_send_outgoing_commands<S: Socket>(
if let Some(compressor) = (*host).compressor.assume_init_mut() {
let original_size: usize = ((*host).packet_size)
.wrapping_sub(::core::mem::size_of::<ENetProtocolHeader>());
let mut in_buffers = vec![];
let mut in_buffers: [&[u8]; ENET_BUFFER_MAXIMUM as usize] =
std::array::from_fn(|_| {
from_raw_parts_or_empty::<u8>(std::ptr::null(), 0)
});
#[allow(clippy::needless_range_loop)]
for i in 0..((*host).buffer_count).wrapping_sub(1) {
let buffer = ((*host).buffers).as_mut_ptr().add(1 + i);
in_buffers.push(super::from_raw_parts_or_empty(
in_buffers[i] = super::from_raw_parts_or_empty(
(*buffer).data,
(*buffer).data_length,
));
);
}
let compressed_size: usize = compressor.compress(
in_buffers,
&in_buffers[0..((*host).buffer_count).wrapping_sub(1)],
original_size,
super::from_raw_parts_or_empty_mut(
((*host).packet_data[1_i32 as usize]).as_mut_ptr(),
Expand Down Expand Up @@ -2265,13 +2269,17 @@ unsafe fn enet_protocol_send_outgoing_commands<S: Socket>(
*fresh35 = (*fresh35 as u64)
.wrapping_add(::core::mem::size_of::<u32>() as u64)
as usize;
let mut in_buffers = vec![];
let mut in_buffers: [&[u8]; ENET_BUFFER_MAXIMUM as usize] =
std::array::from_fn(|_| {
from_raw_parts_or_empty::<u8>(std::ptr::null(), 0)
});
#[allow(clippy::needless_range_loop)]
for i in 0..(*host).buffer_count {
let buffer = ((*host).buffers).as_mut_ptr().add(i);
in_buffers.push(super::from_raw_parts_or_empty(
in_buffers[i] = super::from_raw_parts_or_empty(
(*buffer).data,
(*buffer).data_length,
));
);
}
checksum = checksum_fn(&in_buffers);
copy_nonoverlapping(
Expand All @@ -2288,6 +2296,7 @@ unsafe fn enet_protocol_send_outgoing_commands<S: Socket>(
}
(*current_peer).last_send_time = (*host).service_time;
let mut conglomerate_buffer = vec![];
conglomerate_buffer.reserve_exact(ENET_BUFFER_MAXIMUM as usize);
for buffer_index in 0..(*host).buffer_count {
let buffer = &(*host).buffers[buffer_index];
conglomerate_buffer.extend_from_slice(super::from_raw_parts_or_empty(
Expand Down
21 changes: 12 additions & 9 deletions src/compressor.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use std::mem::zeroed;

use crate::{
enet_range_coder_compress, enet_range_coder_create, enet_range_coder_decompress,
enet_range_coder_destroy, ENetBuffer, ENetRangeCoder,
consts::ENET_BUFFER_MAXIMUM, enet_range_coder_compress, enet_range_coder_create,
enet_range_coder_decompress, enet_range_coder_destroy, ENetBuffer, ENetRangeCoder,
};

/// An interface for compressing ENet packets.
pub trait Compressor {
/// Compress the incoming buffers.
fn compress(&mut self, in_buffers: Vec<&[u8]>, in_limit: usize, out: &mut [u8]) -> usize;
fn compress(&mut self, in_buffers: &[&[u8]], in_limit: usize, out: &mut [u8]) -> usize;
/// Decompress the buffer.
fn decompress(&mut self, in_data: &[u8], out: &mut [u8]) -> usize;
}
Expand All @@ -32,19 +34,20 @@ impl Default for RangeCoder {
}

impl Compressor for RangeCoder {
fn compress(&mut self, in_buffers: Vec<&[u8]>, in_limit: usize, out: &mut [u8]) -> usize {
fn compress(&mut self, in_buffers: &[&[u8]], in_limit: usize, out: &mut [u8]) -> usize {
unsafe {
let mut buffers = vec![];
for in_buffer in in_buffers {
buffers.push(ENetBuffer {
let mut buffers: [ENetBuffer; ENET_BUFFER_MAXIMUM as usize] =
std::array::from_fn(|_| zeroed());
for (i, in_buffer) in in_buffers.iter().enumerate() {
buffers[i] = ENetBuffer {
data: in_buffer.as_ptr().cast_mut(),
data_length: in_buffer.len(),
});
};
}
enet_range_coder_compress(
self.0.cast(),
buffers.as_ptr(),
buffers.len(),
in_buffers.len(),
in_limit,
out.as_mut_ptr(),
out.len(),
Expand Down
6 changes: 4 additions & 2 deletions src/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ pub const ENET_PROTOCOL_MINIMUM_CHANNEL_COUNT: u32 = 1;
pub const ENET_PROTOCOL_MAXIMUM_WINDOW_SIZE: u32 = 65536;
pub const ENET_PROTOCOL_MINIMUM_WINDOW_SIZE: u32 = 4096;
pub const ENET_PROTOCOL_MAXIMUM_PACKET_COMMANDS: u32 = 32;
pub const ENET_PROTOCOL_MAXIMUM_MTU: u32 = 4096;
pub const ENET_PROTOCOL_MINIMUM_MTU: u32 = 576;
pub const ENET_PROTOCOL_MAXIMUM_MTU: usize = 4096;
pub const ENET_PROTOCOL_MINIMUM_MTU: usize = 576;
pub const ENET_PEER_FREE_RELIABLE_WINDOWS: u32 = 8;
pub const ENET_PEER_RELIABLE_WINDOW_SIZE: u32 = 4096;
pub const ENET_PEER_RELIABLE_WINDOWS: u32 = 16;
Expand All @@ -33,3 +33,5 @@ pub const ENET_HOST_DEFAULT_MTU: u32 = 1392;
pub const ENET_HOST_BANDWIDTH_THROTTLE_INTERVAL: u32 = 1000;
pub const ENET_HOST_SEND_BUFFER_SIZE: u32 = 262144;
pub const ENET_HOST_RECEIVE_BUFFER_SIZE: u32 = 262144;

pub const ENET_BUFFER_MAXIMUM: u32 = ENET_PROTOCOL_MAXIMUM_PACKET_COMMANDS * 2 + 1;
14 changes: 14 additions & 0 deletions src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ impl<S: Socket> Host<S> {
)
.map_err(|err| HostNewError::FailedToInitializeSocket(err))?;
let mut peers = vec![];
peers.reserve_exact((*host).peer_count);
for peer_index in 0..(*host).peer_count {
peers.push(Peer((*host).peers.add(peer_index)));
}
Expand Down Expand Up @@ -375,6 +376,19 @@ impl<S: Socket> Host<S> {
Ok(())
}

/// The maximum transmission unit, or the maximum packet size that will be sent by this host.
#[must_use]
pub fn mtu(&self) -> u16 {
unsafe { (*self.host).mtu as u16 }
}

/// Set the maximum transmission unit. See [`Host::mtu`].
pub fn set_mtu(&self, mtu: u16) {
unsafe {
(*self.host).mtu = mtu as u32;
}
}

fn create_event<'a>(&'a mut self, event: &ENetEvent<S>) -> Event<'a, S> {
match event.type_0 {
ENET_EVENT_TYPE_CONNECT => Event::Connect {
Expand Down
14 changes: 14 additions & 0 deletions src/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,20 @@ impl<S: Socket> Peer<S> {
unsafe { enet_peer_throttle_configure(self.0, interval, acceleration, deceleration) }
}

/// The maximum transmission unit of this peer. See [`Host::mtu`](`crate::Host::mtu`).
#[must_use]
pub fn mtu(&self) -> u16 {
unsafe { (*self.0).mtu as u16 }
}

/// Set the maximum transmission unit for this peer. See
/// [`Host::set_mtu`](`crate::Host::set_mtu`).
pub fn set_mtu(&self, mtu: u16) {
unsafe {
(*self.0).mtu = mtu as u32;
}
}

/// Get the current state of the peer.
#[must_use]
pub fn state(&self) -> PeerState {
Expand Down
20 changes: 15 additions & 5 deletions src/read_write.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::collections::VecDeque;
use std::{
collections::VecDeque,
io::{copy, Cursor},
};

use crate::{Address, PacketReceived, Socket, SocketOptions};
use crate::{Address, PacketReceived, Socket, SocketOptions, MTU_MAX};

/// Provides a Read/Write interface for use with [`Host`](`crate::Host`).
///
Expand Down Expand Up @@ -81,11 +84,18 @@ impl<A: Address + 'static, E: std::error::Error + Send + Sync + 'static> Socket
Ok(buffer.len())
}

fn receive(&mut self, _mtu: usize) -> Result<Option<(A, PacketReceived)>, E> {
fn receive(&mut self, buffer: &mut [u8; MTU_MAX]) -> Result<Option<(A, PacketReceived)>, E> {
if let Some(error) = self.error.take() {
Err(error)
} else if let Some((address, buffer)) = self.inbound.pop_front() {
Ok(Some((address, PacketReceived::Complete(buffer))))
} else if let Some((address, inbound)) = self.inbound.pop_front() {
let bytes = inbound.len();
if bytes <= MTU_MAX {
copy(&mut Cursor::new(inbound), &mut Cursor::new(&mut buffer[..]))
.expect("Buffer copy should not fail.");
Ok(Some((address, PacketReceived::Complete(bytes))))
} else {
Ok(None)
}
} else {
Ok(None)
}
Expand Down
45 changes: 31 additions & 14 deletions src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,17 @@ use std::{
net::{SocketAddr, UdpSocket},
};

use crate::Address;
use crate::{consts::ENET_PROTOCOL_MAXIMUM_MTU, Address};

/// The maximum amount of bytes ENet will ever send or receive. Useful for allocating buffers when
/// sending and receiving.
///
/// The actual MTU used by hosts and peers is typically much lower than this maximum and can be
/// changed with [`Host::set_mtu`](`crate::Host::set_mtu`) and
/// [`Peer::set_mtu`](`crate::Peer::set_mtu`).
///
/// A shorter an easier to remember equivalent to [`ENET_PROTOCOL_MAXIMUM_MTU`].
pub const MTU_MAX: usize = ENET_PROTOCOL_MAXIMUM_MTU;

/// Socket options provided by ENet and passed to [`Socket::init`] when creating a
/// [`Host`](`crate::Host`).
Expand Down Expand Up @@ -33,18 +43,26 @@ pub trait Socket: Sized {

/// Initialize the socket with options passed down by ENet.
///
/// Called in [`Host::new`](`crate::Host::new`).
/// Called in [`Host::new`](`crate::Host::new`). If this function returns an error, it is
/// bubbled up through [`Host::new`](`crate::Host::new`).
fn init(&mut self, _socket_options: SocketOptions) -> Result<(), Self::Error> {
Ok(())
}
/// Try to send data. Should return the number of bytes successfully sent, or an error.
fn send(&mut self, address: Self::Address, buffer: &[u8]) -> Result<usize, Self::Error>;
/// Try to receive data. May return an error, or optionally, a data packet.

/// Try to receive data from the socket into a buffer of size [`MTU_MAX`].
///
/// Data packets are wrapped in [`PacketReceived`]. See its docs for more info.
/// A received packet should be written into the provided buffer. If a packet is received that
/// is larger than [`MTU_MAX`], it should simply be discarded. ENet will never send a packet
/// that is larger than this maximum, so if one is received, it was not sent by ENet.
///
/// The return value should be `Ok(None)` if no packet was received. If a packet was received,
/// the address of the peer socket, as well as the amount of bytes received should be returned.
/// Packets received may be complete or partial. See [`PacketReceived`] for more info.
fn receive(
&mut self,
mtu: usize,
buffer: &mut [u8; MTU_MAX],
) -> Result<Option<(Self::Address, PacketReceived)>, Self::Error>;
}

Expand All @@ -60,8 +78,8 @@ pub trait Socket: Sized {
/// [`PacketReceived::Complete`].
#[derive(Debug)]
pub enum PacketReceived {
/// A complete packet was received.
Complete(Vec<u8>),
/// A complete packet was received. The inner value is the size of the packet in bytes.
Complete(usize),
/// A partial packet was received.
Partial,
}
Expand All @@ -85,15 +103,14 @@ impl Socket for UdpSocket {
}
}

fn receive(&mut self, mtu: usize) -> Result<Option<(SocketAddr, PacketReceived)>, io::Error> {
let mut buffer = vec![0; mtu];
match self.recv_from(&mut buffer) {
fn receive(
&mut self,
buffer: &mut [u8; MTU_MAX],
) -> Result<Option<(SocketAddr, PacketReceived)>, io::Error> {
match self.recv_from(buffer) {
Ok((recv_length, recv_addr)) => {
// TODO: MSG_TRUNC? (not supported by rust stdlib)
Ok(Some((
recv_addr,
PacketReceived::Complete(Vec::from(&buffer[0..recv_length])),
)))
Ok(Some((recv_addr, PacketReceived::Complete(recv_length))))
}
Err(err) if err.kind() == ErrorKind::WouldBlock => Ok(None),
Err(err) => Err(err),
Expand Down
Loading

0 comments on commit 033fb1f

Please sign in to comment.