Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interface refactor #7

Merged
merged 6 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
289 changes: 64 additions & 225 deletions src/interface.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,60 @@
use std::{
collections::HashMap,
net::{IpAddr, SocketAddr},
str::FromStr,
};

use super::cerr;

pub struct InterfaceIterator {
pub fn interfaces() -> std::io::Result<HashMap<InterfaceName, InterfaceData>> {
let mut elements = HashMap::default();

for data in InterfaceIterator::new()? {
let current: &mut InterfaceData = elements.entry(data.name).or_default();

current.socket_addrs.extend(data.socket_addr);
assert!(!(current.mac.is_some() && data.mac.is_some()));
current.mac = current.mac.or(data.mac);
}

Ok(elements)
}

#[derive(Default, Debug)]
pub struct InterfaceData {
socket_addrs: Vec<SocketAddr>,
mac: Option<[u8; 6]>,
}

impl InterfaceData {
pub fn has_ip_addr(&self, address: IpAddr) -> bool {
self.socket_addrs
.iter()
.any(|socket_addr| socket_addr.ip() == address)
}

pub fn mac(&self) -> Option<[u8; 6]> {
self.mac
}
}

struct InterfaceIterator {
base: *mut libc::ifaddrs,
next: *mut libc::ifaddrs,
}

impl InterfaceIterator {
pub fn new() -> std::io::Result<Self> {
let mut addrs = core::mem::MaybeUninit::<*mut libc::ifaddrs>::uninit();
let mut addrs: *mut libc::ifaddrs = std::ptr::null_mut();

unsafe {
cerr(libc::getifaddrs(addrs.as_mut_ptr()))?;
cerr(libc::getifaddrs(&mut addrs))?;

assert!(!addrs.is_null());

Ok(Self {
base: addrs.assume_init(),
next: addrs.assume_init(),
base: addrs,
next: addrs,
})
}
}
Expand All @@ -31,23 +66,14 @@ impl Drop for InterfaceIterator {
}
}

pub struct InterfaceData {
pub name: InterfaceName,
pub mac: Option<[u8; 6]>,
pub socket_addr: Option<SocketAddr>,
}

impl InterfaceData {
pub fn has_ip_addr(&self, address: IpAddr) -> bool {
match self.socket_addr {
None => false,
Some(socket_addr) => socket_addr.ip() == address,
}
}
struct InterfaceDataInternal {
name: InterfaceName,
mac: Option<[u8; 6]>,
socket_addr: Option<SocketAddr>,
}

impl Iterator for InterfaceIterator {
type Item = InterfaceData;
type Item = InterfaceDataInternal;

fn next(&mut self) -> Option<<Self as Iterator>::Item> {
let ifaddr = unsafe { self.next.as_ref() }?;
Expand Down Expand Up @@ -80,7 +106,7 @@ impl Iterator for InterfaceIterator {

let socket_addr = unsafe { sockaddr_to_socket_addr(ifaddr.ifa_addr) };

let data = InterfaceData {
let data = InterfaceDataInternal {
name,
mac,
socket_addr,
Expand All @@ -90,7 +116,7 @@ impl Iterator for InterfaceIterator {
}
}

#[derive(Clone, Copy, PartialEq, Eq)]
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct InterfaceName {
bytes: [u8; libc::IFNAMSIZ],
}
Expand Down Expand Up @@ -127,7 +153,7 @@ impl InterfaceName {
}

pub fn from_socket_addr(local_addr: SocketAddr) -> std::io::Result<Option<Self>> {
let matches_inferface = |interface: &InterfaceData| match interface.socket_addr {
let matches_inferface = |interface: &InterfaceDataInternal| match interface.socket_addr {
None => false,
Some(address) => address.ip() == local_addr.ip(),
};
Expand All @@ -138,14 +164,14 @@ impl InterfaceName {
}
}

pub fn get_index(self) -> Option<libc::c_uint> {
// Temporary implementation until great refactor
InterfaceDescriptor {
interface_name: Some(self),
// doesn't matter
mode: LinuxNetworkMode::Ipv4,
pub fn get_index(&self) -> Option<libc::c_uint> {
// # SAFETY
//
// The pointer is valid and null-terminated
match unsafe { libc::if_nametoindex(self.as_cstr().as_ptr()) } {
0 => None,
n => Some(n),
}
.get_index()
}
}

Expand Down Expand Up @@ -197,12 +223,6 @@ impl<'de> serde::Deserialize<'de> for InterfaceName {
}
}

#[derive(Debug, Clone)]
pub struct InterfaceDescriptor {
pub interface_name: Option<InterfaceName>,
pub mode: LinuxNetworkMode,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LinuxNetworkMode {
Ipv4,
Expand All @@ -218,94 +238,6 @@ impl LinuxNetworkMode {
}
}

fn cannot_iterate_interfaces() -> std::io::Error {
let msg = "Could not iterate over interfaces";
std::io::Error::new(std::io::ErrorKind::Other, msg)
}

fn interface_does_not_exist() -> std::io::Error {
let msg = "The specified interface does not exist";
std::io::Error::new(std::io::ErrorKind::Other, msg)
}

impl InterfaceDescriptor {
pub fn get_index(&self) -> Option<libc::c_uint> {
let name = self.interface_name.as_ref()?;

// # SAFETY
//
// The pointer is valid and null-terminated
match unsafe { libc::if_nametoindex(name.as_cstr().as_ptr()) } {
0 => None,
n => Some(n),
}
}

pub fn get_address(&self) -> std::io::Result<IpAddr> {
if let Some(name) = self.interface_name {
let interfaces = InterfaceIterator::new().map_err(|_| cannot_iterate_interfaces())?;

interfaces
.filter(|i| name == i.name)
.filter_map(|i| i.socket_addr)
.map(|socket_addr| socket_addr.ip())
.find(|ip| match self.mode {
LinuxNetworkMode::Ipv4 => ip.is_ipv4(),
LinuxNetworkMode::Ipv6 => ip.is_ipv6(),
})
.ok_or(interface_does_not_exist())
} else {
Ok(self.mode.unspecified_ip_addr())
}
}
}

impl FromStr for InterfaceDescriptor {
type Err = std::io::Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut interfaces = match InterfaceIterator::new() {
Ok(a) => a,
Err(_) => return Err(cannot_iterate_interfaces()),
};

match std::net::IpAddr::from_str(s) {
Ok(addr) => {
if addr.is_unspecified() {
return Ok(InterfaceDescriptor {
interface_name: None,
mode: match addr {
IpAddr::V4(_) => LinuxNetworkMode::Ipv4,
IpAddr::V6(_) => LinuxNetworkMode::Ipv6,
},
});
}

interfaces
.find(|data| data.has_ip_addr(addr))
.map(|data| InterfaceDescriptor {
interface_name: Some(data.name),
mode: LinuxNetworkMode::Ipv4,
})
.ok_or(interface_does_not_exist())
}
Err(_) => {
if interfaces.any(|if_data| if_data.name.as_str() == s) {
// the interface name came straight from the OS, so it must be valid
let interface_name = InterfaceName::from_str(s).unwrap();

Ok(InterfaceDescriptor {
interface_name: Some(interface_name),
mode: LinuxNetworkMode::Ipv4,
})
} else {
Err(interface_does_not_exist())
}
}
}
}
}

/// Convert a libc::sockaddr to a rust std::net::SocketAddr
///
/// # Safety
Expand Down Expand Up @@ -371,7 +303,7 @@ pub fn sockaddr_storage_to_socket_addr(

#[cfg(test)]
mod tests {
use std::net::{Ipv4Addr, Ipv6Addr};
use std::net::Ipv4Addr;

use super::*;

Expand Down Expand Up @@ -420,110 +352,17 @@ mod tests {
}

#[test]
fn test_interface_from_str() {
let interface = InterfaceDescriptor::from_str("0.0.0.0").unwrap();

assert!(matches!(interface.mode, LinuxNetworkMode::Ipv4));
assert!(interface.interface_name.is_none());

let interface = InterfaceDescriptor::from_str("::").unwrap();

assert!(matches!(interface.mode, LinuxNetworkMode::Ipv6));
assert!(interface.interface_name.is_none());

let interface = InterfaceDescriptor::from_str("lo").unwrap();

assert!(matches!(interface.mode, LinuxNetworkMode::Ipv4));
assert_eq!(interface.interface_name.unwrap(), InterfaceName::LOOPBACK);

let error = InterfaceDescriptor::from_str("xxx").unwrap_err();

assert_eq!(error.to_string(), interface_does_not_exist().to_string());
fn interface_index_ipv4() {
assert!(InterfaceName::LOOPBACK.get_index().is_some());
}

#[tokio::test]
async fn get_address_ipv4_invalid() {
let interface = InterfaceDescriptor {
interface_name: Some(InterfaceName::from_str("invalid").unwrap()),
mode: LinuxNetworkMode::Ipv4,
};

assert_eq!(
interface.get_address().unwrap_err().to_string(),
interface_does_not_exist().to_string()
);
}

#[tokio::test]
async fn get_address_ipv6_invalid() {
let interface = InterfaceDescriptor {
interface_name: Some(InterfaceName::from_str("invalid").unwrap()),
mode: LinuxNetworkMode::Ipv6,
};

assert_eq!(
interface.get_address().unwrap_err().to_string(),
interface_does_not_exist().to_string()
);
}

#[tokio::test]
async fn interface_index_ipv4() -> std::io::Result<()> {
let interface = InterfaceDescriptor {
interface_name: Some(InterfaceName::LOOPBACK),
mode: LinuxNetworkMode::Ipv4,
};

assert!(interface.get_index().is_some());

Ok(())
}

#[tokio::test]
async fn interface_index_ipv6() -> std::io::Result<()> {
let interface = InterfaceDescriptor {
interface_name: Some(InterfaceName::LOOPBACK),
mode: LinuxNetworkMode::Ipv6,
};

assert!(interface.get_index().is_some());

Ok(())
}

#[tokio::test]
async fn interface_index_invalid() -> std::io::Result<()> {
let interface = InterfaceDescriptor {
interface_name: Some(InterfaceName::INVALID),
mode: LinuxNetworkMode::Ipv4,
};

assert!(interface.get_index().is_none());

Ok(())
}

#[tokio::test]
async fn get_address_ipv4_valid() -> Result<(), Box<dyn std::error::Error>> {
let interface = InterfaceDescriptor {
interface_name: Some(InterfaceName::LOOPBACK),
mode: LinuxNetworkMode::Ipv4,
};

assert_eq!(interface.get_address()?, Ipv4Addr::LOCALHOST);

Ok(())
#[test]
fn interface_index_ipv6() {
assert!(InterfaceName::LOOPBACK.get_index().is_some());
}

#[tokio::test]
async fn get_address_ipv6_valid() -> Result<(), Box<dyn std::error::Error>> {
let interface = InterfaceDescriptor {
interface_name: Some(InterfaceName::LOOPBACK),
mode: LinuxNetworkMode::Ipv6,
};

assert_eq!(interface.get_address()?, Ipv6Addr::LOCALHOST);

Ok(())
#[test]
fn interface_index_invalid() {
assert!(InterfaceName::INVALID.get_index().is_none());
}
}
Loading
Loading