From ce6b1044ae7f355d3281930a1a48264e4cc52948 Mon Sep 17 00:00:00 2001 From: Michael Martin Date: Wed, 20 Dec 2023 17:04:56 -0800 Subject: [PATCH] chore(*): refactor and expand test coverage --- src/cli.rs | 509 +++++++++++++++++++++++++++++++++++++-------------- src/lua.rs | 2 +- src/main.rs | 2 +- src/nginx.rs | 52 +++--- src/run.rs | 2 +- src/types.rs | 169 ++++++++--------- src/util.rs | 94 ++++++++-- 7 files changed, 559 insertions(+), 271 deletions(-) diff --git a/src/cli.rs b/src/cli.rs index 42ad55a..9337980 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -234,7 +234,7 @@ fn include_file(section: &str, fname: String) -> Result { Ok(path.to_str().expect("uh oh").to_string()) } -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub enum Action { Help, Version(Option), @@ -352,7 +352,7 @@ impl Action { } } -#[derive(Default, Debug)] +#[derive(Default, Debug, PartialEq, Eq)] pub struct UserArgs { pub(crate) inline_lua: Vec, pub(crate) lua_file: Option, @@ -396,203 +396,432 @@ fn discover_system_nameservers() -> Vec { ns } -pub fn init(args: Vec) -> Result { - let mut user = UserArgs { - arg_c: args.len(), - worker_connections: 64, - ..Default::default() - }; +impl Action { + pub fn try_from(args: T) -> Result + where + T: IntoIterator, + { + let mut args: VecDeque = args.into_iter().collect(); + + let mut user = UserArgs { + arg_c: args.len(), + worker_connections: 64, + ..Default::default() + }; - let runner = &mut user.runner; + let runner = &mut user.runner; - let mut show_version = false; + let mut show_version = false; - let mut args: VecDeque = args.into(); - user.arg_0 = args.pop_front().expect("empgy ARGV[0]"); + user.arg_0 = args.pop_front().ok_or(ArgError::EmptyArgv0)?; - while let Some(arg) = args.pop_front() { - let mut optarg: Option = None; - let mut combined_opt_arg = false; - let mut arg = arg; + while let Some(arg) = args.pop_front() { + let mut optarg: Option = None; + let mut combined_opt_arg = false; + let mut arg = arg; - if arg.is_opt() { - if let Some((a, o)) = arg.parse_opt_eq() { - arg = a; - optarg = Some(o); - combined_opt_arg = true; - } else { - optarg = args.pop_front(); + if arg.is_opt() { + if let Some((a, o)) = arg.parse_opt_eq() { + arg = a; + optarg = Some(o); + combined_opt_arg = true; + } else { + optarg = args.pop_front(); + } } - } - let mut end_of_args = false; + let mut end_of_args = false; - let optarg = &mut optarg; + let optarg = &mut optarg; - match arg.as_str() { - "-v" | "-V" | "--version" => { - show_version = true; - if user.nginx_bin.is_some() { - return Ok(Action::Version(user.nginx_bin)); + match arg.as_str() { + "-v" | "-V" | "--version" => { + show_version = true; + if user.nginx_bin.is_some() { + return Ok(Action::Version(user.nginx_bin)); + } } - } - "-h" | "--help" => { - return Ok(Action::Help); - } + "-h" | "--help" => { + return Ok(Action::Help); + } - "--gdb" => runner.update(Runner::Gdb(None))?, + "--gdb" => runner.update(Runner::Gdb(None))?, - "--gdb-opts" => { - let opts = arg.get_arg(optarg)?; - runner.update(Runner::Gdb(Some(opts)))?; - } + "--gdb-opts" => { + let opts = arg.get_arg(optarg)?; + runner.update(Runner::Gdb(Some(opts)))?; + } - "--rr" => runner.update(Runner::RR)?, + "--rr" => runner.update(Runner::RR)?, - "--stap" => runner.update(Runner::Stap(None))?, + "--stap" => runner.update(Runner::Stap(None))?, - "--stap-opts" => { - let opts = arg.get_arg(optarg)?; - runner.update(Runner::Stap(Some(opts)))?; - } + "--stap-opts" => { + let opts = arg.get_arg(optarg)?; + runner.update(Runner::Stap(Some(opts)))?; + } - "--user-runner" => { - let opts = arg.get_arg(optarg)?; - runner.update(Runner::User(opts))?; - } + "--user-runner" => { + let opts = arg.get_arg(optarg)?; + runner.update(Runner::User(opts))?; + } - "--valgrind" => runner.update(Runner::Valgrind(None))?, + "--valgrind" => runner.update(Runner::Valgrind(None))?, - "--valgrind-opts" => { - let opts = arg.get_arg(optarg)?; - runner.update(Runner::Valgrind(Some(opts)))?; - } + "--valgrind-opts" => { + let opts = arg.get_arg(optarg)?; + runner.update(Runner::Valgrind(Some(opts)))?; + } - "-c" => { - user.worker_connections = arg.parse_to(optarg)?; - } + "-c" => { + user.worker_connections = arg.parse_to(optarg)?; + } - "--http-conf" => { - arg.push_to(&mut user.http_conf, optarg)?; - } + "--http-conf" => { + arg.push_to(&mut user.http_conf, optarg)?; + } - "--http-include" => { - let value = arg.get_arg(optarg)?; - user.http_include.push(include_file("http", value)?); - } + "--http-include" => { + let value = arg.get_arg(optarg)?; + user.http_include.push(include_file("http", value)?); + } - "--main-conf" => { - arg.push_to(&mut user.main_conf, optarg)?; - } + "--main-conf" => { + arg.push_to(&mut user.main_conf, optarg)?; + } - "--main-include" => { - let value = arg.get_arg(optarg)?; - user.main_include.push(include_file("main", value)?); - } + "--main-include" => { + let value = arg.get_arg(optarg)?; + user.main_include.push(include_file("main", value)?); + } - "--errlog-level" => { - user.errlog_level = arg.parse_to(optarg)?; - } + "--errlog-level" => { + user.errlog_level = arg.parse_to(optarg)?; + } - "--nginx" => { - let value = arg.get_arg(optarg)?; - user.nginx_bin = Some(value); + "--nginx" => { + let value = arg.get_arg(optarg)?; + user.nginx_bin = Some(value); - if show_version { - return Ok(Action::Version(user.nginx_bin)); + if show_version { + return Ok(Action::Version(user.nginx_bin)); + } } - } - "--no-stream" => { - user.no_stream = true; - } + "--no-stream" => { + user.no_stream = true; + } - "--stream-conf" => { - arg.push_to(&mut user.stream_conf, optarg)?; - } + "--stream-conf" => { + arg.push_to(&mut user.stream_conf, optarg)?; + } - "--ns" => { - arg.push_to(&mut user.nameservers, optarg)?; - } + "--ns" => { + arg.push_to(&mut user.nameservers, optarg)?; + } - "--resolve-ipv6" => { - user.resolve_ipv6 = true; - } + "--resolve-ipv6" => { + user.resolve_ipv6 = true; + } - "--shdict" => { - arg.push_to(&mut user.user_shdicts, optarg)?; - } + "--shdict" => { + arg.push_to(&mut user.user_shdicts, optarg)?; + } - "-I" => { - arg.push_to(&mut user.lua_package_path, optarg)?; - } + "-I" => { + arg.push_to(&mut user.lua_package_path, optarg)?; + } - "-j" => { - user.jit_cmd = Some(arg.parse_to(optarg)?); - } + "-j" => { + user.jit_cmd = Some(arg.parse_to(optarg)?); + } - "-l" => { - let value = arg.get_arg(optarg)?; - user.inline_lua - .push(format!("require({})", value.lua_quote())); - } + "-l" => { + let value = arg.get_arg(optarg)?; + user.inline_lua + .push(format!("require({})", value.lua_quote())); + } - "-e" => { - arg.push_to(&mut user.inline_lua, optarg)?; - } + "-e" => { + arg.push_to(&mut user.inline_lua, optarg)?; + } + + "--" => { + user.lua_file = optarg.take(); + end_of_args = true; + } - "--" => { - user.lua_file = optarg.take(); - end_of_args = true; + _ => { + if arg.is_opt() { + return Err(ArgError::UnknownArgument(arg)); + } else { + end_of_args = true; + user.lua_file = Some(arg.clone()); + } + } } - _ => { - if arg.is_opt() { - return Err(ArgError::UnknownArgument(arg)); + if let Some(value) = optarg.take() { + if combined_opt_arg { + return Err(ArgError::InvalidValue { + arg, + value, + err: "arg does not take a value".to_string(), + }); } else { - end_of_args = true; - user.lua_file = Some(arg.clone()); + args.push_front(value); } } + + if end_of_args { + break; + } + } + + user.lua_args.extend(args); + + if show_version { + return Ok(Action::Version(user.nginx_bin)); + } + + if user.inline_lua.is_empty() && user.lua_file.is_none() && user.jit_cmd.is_none() { + return Err(ArgError::NoLuaInput); } - if let Some(value) = optarg.take() { - if combined_opt_arg { - return Err(ArgError::InvalidValue { - arg, - value, - err: "arg does not take a value".to_string(), - }); - } else { - args.push_front(value); + if let Some(fname) = &user.lua_file { + if File::open(fname).is_err() { + return Err(ArgError::LuaFileNotFound(fname.to_string())); } } - if end_of_args { - break; + if user.nameservers.is_empty() { + user.nameservers.extend(discover_system_nameservers()); } + + Ok(Action::Main(Box::new(user))) } +} - user.lua_args.extend(args); +#[cfg(test)] +mod tests { + use super::*; - if show_version { - return Ok(Action::Version(user.nginx_bin)); + macro_rules! action { + ( $( $x:expr ),* ) => { + { + let v = vec![$($x.to_string(),)*]; + Action::try_from(v) + } + }; } - if user.inline_lua.is_empty() && user.lua_file.is_none() && user.jit_cmd.is_none() { - return Err(ArgError::NoLuaInput); + macro_rules! svec { + ( $( $x:expr ),* ) => { + { + let v: Vec = vec![$($x.to_string(),)*]; + v + } + }; } - if let Some(fname) = &user.lua_file { - if File::open(fname).is_err() { - return Err(ArgError::LuaFileNotFound(fname.to_string())); + #[test] + fn empty_args() { + assert!(matches!(action!(), Err(ArgError::EmptyArgv0))); + } + + #[test] + fn no_lua() { + assert_eq!(Err(ArgError::NoLuaInput), action!("bin")); + } + + #[test] + fn version_action() { + let exp = Ok(Action::Version(None)); + + assert_eq!(exp, action!("bin", "--version")); + assert_eq!(exp, action!("bin", "-V")); + assert_eq!(exp, action!("bin", "-v")); + } + + #[test] + fn version_action_with_additional_args() { + let exp = Ok(Action::Version(None)); + assert_eq!(exp, action!("bin", "-v", "--no-stream")); + assert_eq!( + exp, + action!("bin", "-v", "--no-stream", "-e", "ngx.sleep(1)") + ); + } + + #[test] + fn version_action_with_nginx_bin() { + let exp = Ok(Action::Version(Some("my-nginx".to_string()))); + + assert_eq!(exp, action!("bin", "--version", "--nginx", "my-nginx")); + assert_eq!(exp, action!("bin", "-V", "--nginx", "my-nginx")); + assert_eq!(exp, action!("bin", "-v", "--nginx", "my-nginx")); + + assert_eq!(exp, action!("bin", "--version", "--nginx=my-nginx")); + assert_eq!(exp, action!("bin", "-V", "--nginx=my-nginx")); + assert_eq!(exp, action!("bin", "-v", "--nginx=my-nginx")); + + assert_eq!(exp, action!("bin", "--nginx", "my-nginx", "--version")); + assert_eq!(exp, action!("bin", "--nginx", "my-nginx", "-V")); + assert_eq!(exp, action!("bin", "--nginx", "my-nginx", "-v")); + + assert_eq!(exp, action!("bin", "--nginx=my-nginx", "--version")); + assert_eq!(exp, action!("bin", "--nginx=my-nginx", "-V")); + assert_eq!(exp, action!("bin", "--nginx=my-nginx", "-v")); + } + + #[test] + fn help_action() { + let exp = Ok(Action::Help); + assert_eq!(exp, action!("bin", "-h")); + assert_eq!(exp, action!("bin", "--help")); + } + + #[test] + fn help_action_extra_args() { + let exp = Ok(Action::Help); + assert_eq!(exp, action!("bin", "-h", "--no-stream")); + assert_eq!(exp, action!("bin", "--help", "--no-stream")); + assert_eq!(exp, action!("bin", "--no-stream", "-h")); + assert_eq!(exp, action!("bin", "--no-stream", "--help")); + } + + #[test] + fn jit_command_only() { + // resty-cli doesn't actually complain if you pass in a luajit command + // with no -e or lua file + assert!(matches!(action!("bin", "-j", "v"), Ok(Action::Main(_)))); + } + + #[test] + fn opts_that_require_an_arg() { + let opts = vec![ + "--stap-opts", + "--gdb-opts", + "--user-runner", + "--valgrind-opts", + "-c", + "--http-conf", + "--http-include", + "--main-conf", + "--main-include", + "--errlog-level", + "--nginx", + "--stream-conf", + "--ns", + "--shdict", + "-I", + "-j", + "-l", + "-e", + ]; + + for opt in opts { + assert!(matches!( + action!("bin", "-e", "test", opt), + Err(ArgError::MissingValue(_)) + )); } } - if user.nameservers.is_empty() { - user.nameservers.extend(discover_system_nameservers()); + #[test] + fn opts_that_dont_take_an_arg() { + let opts = vec![ + "--no-stream", + "--gdb", + "--rr", + "--stap", + "--valgrind", + "--resolve-ipv6", + ]; + + for opt in opts { + let with_eq = format!("{}=123", opt); + + assert_eq!( + Err(ArgError::InvalidValue { + arg: opt.to_string(), + value: "123".to_string(), + err: "arg does not take a value".to_string() + }), + action!("bin", with_eq) + ); + } } - Ok(Action::Main(Box::new(user))) + #[test] + fn happy_paths() { + #[rustfmt::skip] + let act = action!( + "bin", + "--no-stream", + "--main-conf", "main-1", + "-c", "20", + "--ns", "1.2.3.4", + "--http-conf", "http-1", + "-l", "my-require", + "--ns", "5.6.7.8", + + "--shdict=my_shdict 1m", + + "-I", "/my-include-dir", + "-e", "expr 1", + + "-l=my-other-require", + + "--http-conf", "http-2", + "-j", "off", + "--nginx", "my-nginx", + "-e", "expr 2", + "--main-conf", "main-2", + + "-e=expr 3" + ); + + assert!(matches!(act, Ok(Action::Main(_)))); + + let Action::Main(args) = act.unwrap() else { + panic!("oops") + }; + + assert_eq!(20, args.worker_connections); + assert!(args.no_stream); + + assert_eq!( + svec!["1.2.3.4", "5.6.7.8"], + args.nameservers + .iter() + .map(String::from) + .collect::>() + ); + + assert_eq!(svec!["http-1", "http-2"], args.http_conf); + assert_eq!(svec!["main-1", "main-2"], args.main_conf); + + assert!(matches!(args.jit_cmd, Some(JitCmd::Off))); + + assert_eq!(Some("my-nginx".to_string()), args.nginx_bin); + + assert_eq!( + svec![ + "require([=[my-require]=])", + "expr 1", + "require([=[my-other-require]=])", + "expr 2", + "expr 3" + ], + args.inline_lua + ); + + assert_eq!(None, args.lua_file); + + assert!(matches!(args.runner, Runner::Default)); + + dbg!(&args); + } } diff --git a/src/lua.rs b/src/lua.rs index c92a870..aefc517 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -131,7 +131,7 @@ struct LuaGenerator<'a> { } impl<'a> LuaGenerator<'a> { - pub fn generate(mut self) -> Result, std::io::Error> { + pub(crate) fn generate(mut self) -> Result, std::io::Error> { self.buf.append("local gen"); self.buf.append("do"); self.buf.indent(); diff --git a/src/main.rs b/src/main.rs index 30f1722..ef2ea9f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,7 +8,7 @@ mod util; use std::process::exit; fn main() { - match cli::init(std::env::args().collect()) { + match cli::Action::try_from(std::env::args()) { Err(e) => { eprintln!("{}", e); exit(e.exit_code()); diff --git a/src/nginx.rs b/src/nginx.rs index 2316013..1834725 100644 --- a/src/nginx.rs +++ b/src/nginx.rs @@ -12,7 +12,7 @@ const BLOCK_OPEN: &str = "{"; const BLOCK_CLOSE: &str = "}"; #[derive(Debug, Default)] -pub struct ConfBuilder { +pub(crate) struct ConfBuilder { events: Option>, main: Option>, stream_enabled: bool, @@ -23,11 +23,11 @@ pub struct ConfBuilder { } impl ConfBuilder { - pub fn new() -> Self { + pub(crate) fn new() -> Self { Default::default() } - pub fn events(mut self, t: T) -> Self + pub(crate) fn events(mut self, t: T) -> Self where T: IntoIterator, { @@ -35,7 +35,7 @@ impl ConfBuilder { self } - pub fn main(mut self, t: T) -> Self + pub(crate) fn main(mut self, t: T) -> Self where T: IntoIterator, { @@ -43,7 +43,7 @@ impl ConfBuilder { self } - pub fn http(mut self, t: T) -> Self + pub(crate) fn http(mut self, t: T) -> Self where T: IntoIterator, { @@ -51,7 +51,7 @@ impl ConfBuilder { self } - pub fn stream(mut self, t: T, enabled: bool) -> Self + pub(crate) fn stream(mut self, t: T, enabled: bool) -> Self where T: IntoIterator, { @@ -60,7 +60,7 @@ impl ConfBuilder { self } - pub fn lua(mut self, t: T) -> Self + pub(crate) fn lua(mut self, t: T) -> Self where T: IntoIterator, { @@ -68,12 +68,12 @@ impl ConfBuilder { self } - pub fn resty_compat_version(mut self, v: u64) -> Self { + pub(crate) fn resty_compat_version(mut self, v: u64) -> Self { self.resty_compat_version = Some(v); self } - pub fn render(self, buf: &mut T) -> io::Result<()> + pub(crate) fn render(self, buf: &mut T) -> io::Result<()> where T: io::Write, { @@ -108,13 +108,13 @@ impl From for Conf { } struct Conf { - pub events: Vec, - pub main: Vec, - pub stream_enabled: bool, - pub stream: Vec, - pub http: Vec, - pub lua: Vec, - pub resty_compat_version: u64, + events: Vec, + main: Vec, + stream_enabled: bool, + stream: Vec, + http: Vec, + lua: Vec, + resty_compat_version: u64, } impl Conf { @@ -375,7 +375,7 @@ impl Conf { } } -pub fn find_nginx_bin(nginx: Option) -> PathBuf { +pub(crate) fn find_nginx_bin(nginx: Option) -> PathBuf { if let Some(path) = nginx { return PathBuf::from(path); } @@ -399,7 +399,7 @@ pub fn find_nginx_bin(nginx: Option) -> PathBuf { PathBuf::from("nginx") } -pub fn get_resty_compat_version() -> u64 { +pub(crate) fn get_resty_compat_version() -> u64 { // TODO: maybe make this a build config item? if let Some(value) = env::var_os(RESTY_COMPAT_VAR) { let value = value.to_str().unwrap(); @@ -416,11 +416,11 @@ pub fn get_resty_compat_version() -> u64 { } } -pub struct Exec { - pub prefix: String, - pub runner: Runner, - pub bin: String, - pub label: Option, +pub(crate) struct Exec { + pub(crate) prefix: String, + pub(crate) runner: Runner, + pub(crate) bin: String, + pub(crate) label: Option, } impl From for Command { @@ -506,8 +506,8 @@ impl From for Command { } } -#[derive(Default, Debug)] -pub enum Runner { +#[derive(Default, Debug, PartialEq, Eq)] +pub(crate) enum Runner { #[default] Default, RR, @@ -558,7 +558,7 @@ impl Runner { } } - pub fn update(&mut self, new: Runner) -> Result<(), ArgError> { + pub(crate) fn update(&mut self, new: Runner) -> Result<(), ArgError> { if let Runner::Default = self { *self = new; Ok(()) diff --git a/src/run.rs b/src/run.rs index b4405c1..5c4a2e5 100644 --- a/src/run.rs +++ b/src/run.rs @@ -45,7 +45,7 @@ fn block_wait(mut proc: Child) -> i32 { } } -pub fn run(mut cmd: Command) -> i32 { +pub(crate) fn run(mut cmd: Command) -> i32 { let proc = match cmd.spawn() { Ok(child) => child, Err(e) => { diff --git a/src/types.rs b/src/types.rs index 88047ab..db44626 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,40 +1,19 @@ -use libc::{c_char, mkdtemp}; -use std::ffi; +use crate::util::tempdir; use std::fmt::{Debug, Display, Formatter, Result as FmtResult}; use std::fs; -use std::io; use std::net; use std::path::PathBuf; use std::string::ToString; use thiserror::Error as ThisError; -const MKDTEMP_TEMPLATE: &str = "/tmp/resty_XXXXXX"; - -pub fn tempdir(tpl: Option<&str>) -> io::Result { - let tpl = ffi::CString::new(tpl.unwrap_or(MKDTEMP_TEMPLATE)).unwrap(); - unsafe { - let res = mkdtemp(tpl.as_ptr() as *mut c_char); - - if res.is_null() { - return Err(std::io::Error::last_os_error()); - } - - Ok(ffi::CStr::from_ptr(res).to_str().unwrap().to_string()) - } -} - -#[test] -fn pls_dont_die() { - assert!(tempdir(Some("/tmp/weeee")).is_err()); - assert!(tempdir(None).is_ok()); -} - fn trim_brackets(s: &str) -> &str { s.strip_prefix('[') .and_then(|st| st.strip_suffix(']')) .unwrap_or(s) } +// Rust's IP Address type (std::net::IpAddr) cannot be parsed from a string +// when the string is wrapped in [brackets], so we have our own type. #[derive(Debug, PartialEq, Eq, Clone)] pub(crate) struct IpAddr(String); @@ -50,6 +29,12 @@ impl From for String { } } +impl From<&IpAddr> for String { + fn from(val: &IpAddr) -> Self { + val.0.to_owned() + } +} + impl std::str::FromStr for IpAddr { type Err = String; @@ -61,21 +46,6 @@ impl std::str::FromStr for IpAddr { } } -#[test] -fn test_ip_addr_from_str() { - assert_eq!(Ok(IpAddr("[::1]".to_string())), "[::1]".parse::()); - - assert_eq!( - Ok(IpAddr("[2003:dead:beef:4dad:23:46:bb:101]".to_string())), - "[2003:dead:beef:4dad:23:46:bb:101]".parse::() - ); - - assert_eq!( - Ok(IpAddr("127.0.0.1".to_string())), - "127.0.0.1".parse::() - ); -} - #[derive(Debug, PartialEq, Eq, Clone)] pub(crate) struct Shdict(String); @@ -150,43 +120,7 @@ impl std::str::FromStr for Shdict { } } -#[test] -fn shdict_from_str() { - fn shdict(name: &str, size: &str) -> Shdict { - Shdict(format!("{} {}", name, size)) - } - - fn must_parse(input: &str, (name, size): (&str, &str)) { - assert_eq!(Ok(shdict(name, size)), input.parse::()) - } - - fn must_not_parse(s: &str) { - assert_eq!(Err(InvalidShdict::from(s)), s.parse::()) - } - - must_parse("foo_1 10k", ("foo_1", "10k")); - must_parse("foo_1 10K", ("foo_1", "10K")); - must_parse("foo_1 10m", ("foo_1", "10m")); - must_parse("foo_1 10M", ("foo_1", "10M")); - must_parse("cats_dogs 20000", ("cats_dogs", "20000")); - - must_not_parse(""); - must_not_parse("foo 10 extra"); - must_not_parse("foo 10 extra extra"); - must_not_parse("- 10"); - must_not_parse(" foo 10"); - must_not_parse("foo"); - must_not_parse("foo f"); - must_not_parse("foo 10.0"); - must_not_parse("foo 10b"); - must_not_parse("foo 10g"); - must_not_parse("foo 10km"); - must_not_parse("fo--o 10k"); - must_not_parse("foo -10k"); - must_not_parse("foo -10"); -} - -#[derive(Clone, Debug, Default, strum_macros::Display, strum_macros::EnumString)] +#[derive(Clone, Debug, Default, strum_macros::Display, strum_macros::EnumString, PartialEq, Eq)] #[strum(serialize_all = "lowercase")] pub(crate) enum LogLevel { Debug, @@ -200,7 +134,7 @@ pub(crate) enum LogLevel { Emerg, } -#[derive(Clone, Debug, strum_macros::EnumString)] +#[derive(Clone, Debug, strum_macros::EnumString, PartialEq, Eq)] #[strum(serialize_all = "lowercase")] pub(crate) enum JitCmd { /// Use LuaJIT's jit.v module to output brief info of the @@ -251,9 +185,8 @@ impl Display for Prefix { impl Prefix { pub(crate) fn new() -> Result { - let tmp = tempdir(None)?; + let root = tempdir()?; - let root = PathBuf::from(tmp); let conf = root.join("conf"); fs::create_dir_all(&root)?; @@ -279,15 +212,15 @@ pub(crate) struct Buf { } impl Buf { - pub fn new() -> Self { + pub(crate) fn new() -> Self { Self::default() } - pub fn newline(&mut self) { + pub(crate) fn newline(&mut self) { self.lines.push(String::new()); } - pub fn append(&mut self, s: &str) { + pub(crate) fn append(&mut self, s: &str) { let mut line = String::new(); if self.indent > 0 { @@ -298,22 +231,22 @@ impl Buf { self.lines.push(line); } - pub fn finalize(self) -> Vec { + pub(crate) fn finalize(self) -> Vec { self.lines } - pub fn indent(&mut self) { + pub(crate) fn indent(&mut self) { self.indent += 1 } - pub fn dedent(&mut self) { + pub(crate) fn dedent(&mut self) { assert!(self.indent > 0); self.indent -= 1 } } -#[derive(ThisError, Debug)] -pub enum ArgError { +#[derive(ThisError, Debug, PartialEq, Eq)] +pub(crate) enum ArgError { #[error("ERROR: could not find {0} include file '{1}'")] MissingInclude(String, String), @@ -341,10 +274,13 @@ pub enum ArgError { #[error("Lua input file {0} not found.")] LuaFileNotFound(String), + + #[error("ARGV[0] is empty")] + EmptyArgv0, } impl ArgError { - pub fn exit_code(&self) -> i32 { + pub(crate) fn exit_code(&self) -> i32 { match self { // I/O error Self::MissingInclude(_, _) => 2, @@ -357,6 +293,7 @@ impl ArgError { Self::Conflict(_, _) => 25, Self::UnknownArgument(_) => 1, + Self::EmptyArgv0 => 1, Self::InvalidValue { arg: _, @@ -372,3 +309,59 @@ impl ArgError { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn ip_addr_from_str() { + assert_eq!(Ok(IpAddr("[::1]".to_string())), "[::1]".parse::()); + + assert_eq!( + Ok(IpAddr("[2003:dead:beef:4dad:23:46:bb:101]".to_string())), + "[2003:dead:beef:4dad:23:46:bb:101]".parse::() + ); + + assert_eq!( + Ok(IpAddr("127.0.0.1".to_string())), + "127.0.0.1".parse::() + ); + } + + #[test] + fn shdict_from_str() { + fn shdict(name: &str, size: &str) -> Shdict { + Shdict(format!("{} {}", name, size)) + } + + fn must_parse(input: &str, (name, size): (&str, &str)) { + assert_eq!(Ok(shdict(name, size)), input.parse::()) + } + + fn must_not_parse(s: &str) { + assert_eq!(Err(InvalidShdict::from(s)), s.parse::()) + } + + must_parse("foo_1 10k", ("foo_1", "10k")); + must_parse("foo_1 10K", ("foo_1", "10K")); + must_parse("foo_1 10m", ("foo_1", "10m")); + must_parse("foo_1 10M", ("foo_1", "10M")); + must_parse("cats_dogs 20000", ("cats_dogs", "20000")); + + must_not_parse(""); + must_not_parse("foo 10 extra"); + must_not_parse("foo 10 extra extra"); + must_not_parse("- 10"); + must_not_parse(" foo 10"); + must_not_parse("foo"); + must_not_parse("foo f"); + must_not_parse("foo 10.0"); + must_not_parse("foo 10b"); + must_not_parse("foo 10g"); + must_not_parse("foo 10km"); + must_not_parse("fo--o 10k"); + must_not_parse("foo -10k"); + must_not_parse("foo -10"); + } +} diff --git a/src/util.rs b/src/util.rs index ac63575..546853c 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,6 +1,33 @@ use crate::types::IpAddr; +use libc::{c_char, mkdtemp}; +use std::ffi::{CStr, CString}; use std::fs; -use std::io::{BufRead, BufReader, Read}; +use std::io::{self, BufRead, BufReader, ErrorKind, Read}; +use std::path::PathBuf; + +fn impl_tempdir(tpl: &str) -> io::Result { + use io::Error; + + let tpl = CString::new(tpl).map_err(|e| Error::new(ErrorKind::InvalidInput, e))?; + + unsafe { + let ptr: *const c_char = mkdtemp(tpl.as_ptr() as *mut c_char); + + if ptr.is_null() { + return Err(std::io::Error::last_os_error()); + } + + CStr::from_ptr(ptr) + } + .to_str() + .map_err(|e| Error::new(ErrorKind::Other, e)) + .map(PathBuf::from) +} + +pub(crate) fn tempdir() -> io::Result { + const MKDTEMP_TEMPLATE: &str = "/tmp/resty_XXXXXX"; + impl_tempdir(MKDTEMP_TEMPLATE) +} fn impl_try_parse_resolv_conf(buf: T) -> Vec { BufReader::new(buf) @@ -160,12 +187,13 @@ mod tests { #[test] fn test_impl_try_parse_resolv_conf() { - fn addr>(s: T) -> IpAddr { - s.as_ref().parse().unwrap() - } - - fn addrs>(s: Vec) -> Vec { - s.iter().map(addr).collect() + macro_rules! addrs { + ( $( $x:expr ),* ) => { + { + let v = vec![$($x.parse::().unwrap(),)*]; + v + } + }; } let input = r##"# This is /run/systemd/resolve/stub-resolv.conf managed by man:systemd-resolved(8). @@ -193,14 +221,14 @@ options edns0 trust-ad search example.com"##; assert_eq!( - vec![addr("127.0.0.53")], + addrs!["127.0.0.53"], impl_try_parse_resolv_conf(input.as_bytes()) ); let input = "nameserver 127.0.0.53"; assert_eq!( - vec![addr("127.0.0.53")], + addrs!["127.0.0.53"], impl_try_parse_resolv_conf(input.as_bytes()) ); @@ -211,7 +239,7 @@ nameserver 127.0.0.3 oops extra stuff "##; assert_eq!( - vec![addr("127.0.0.2")], + addrs!["127.0.0.2"], impl_try_parse_resolv_conf(input.as_bytes()) ); @@ -222,7 +250,7 @@ nameserver 127.0.0.3 "##; assert_eq!( - addrs(vec!["127.0.0.1", "127.0.0.2", "127.0.0.3",]), + addrs!["127.0.0.1", "127.0.0.2", "127.0.0.3"], impl_try_parse_resolv_conf(input.as_bytes()) ); @@ -242,7 +270,7 @@ nameserver 127.0.0.12 "##; assert_eq!( - addrs(vec![ + addrs![ "127.0.0.1", "127.0.0.2", "127.0.0.3", @@ -253,9 +281,47 @@ nameserver 127.0.0.12 "127.0.0.8", "127.0.0.9", "127.0.0.10", - "127.0.0.11", - ]), + "127.0.0.11" + ], impl_try_parse_resolv_conf(input.as_bytes()) ); } + + #[test] + fn temp_dir_invalid_template() { + let res = impl_tempdir("/tmp/weeee"); + assert!(res.is_err()); + + let e = res.unwrap_err(); + assert!(matches!(e.kind(), ErrorKind::InvalidInput)); + } + + #[test] + fn temp_dir_invalid_template_string() { + let res = impl_tempdir("/tmp/null_\0_foo_XXXXXX"); + assert!(res.is_err()); + + let e = res.unwrap_err(); + assert!(matches!(e.kind(), ErrorKind::InvalidInput)); + } + + #[test] + fn temp_dir_io_error() { + let res = impl_tempdir("/a/b/i-dont-exist/c/d/e/f/g/foo_XXXXXX"); + assert!(res.is_err()); + + let e = res.unwrap_err(); + assert!(matches!(e.kind(), ErrorKind::NotFound)); + } + + #[test] + fn temp_dir_happy_path() { + let result = impl_tempdir("/tmp/cargo_test_XXXXXX"); + assert!(result.is_ok()); + + let path = result.unwrap(); + assert!(path.is_dir()); + + std::fs::remove_dir_all(path).expect("cleanup of temp dir"); + } }