From 6be56761a4ad9e5a72dad6227fb3d8c3b60bc977 Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Mon, 9 Sep 2024 16:58:11 +1000 Subject: [PATCH] c --- crates/polars-io/src/path_utils/mod.rs | 94 +++++++++++++++++++------- py-polars/tests/unit/io/test_scan.py | 11 +++ 2 files changed, 81 insertions(+), 24 deletions(-) diff --git a/crates/polars-io/src/path_utils/mod.rs b/crates/polars-io/src/path_utils/mod.rs index 5c4e48f7e6e43..4b3bd61a2790c 100644 --- a/crates/polars-io/src/path_utils/mod.rs +++ b/crates/polars-io/src/path_utils/mod.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use once_cell::sync::Lazy; use polars_core::config; use polars_core::error::{polars_bail, to_compute_err, PolarsError, PolarsResult}; +use polars_utils::pl_str::PlSmallStr; use regex::Regex; #[cfg(feature = "cloud")] @@ -135,7 +136,63 @@ pub fn expand_paths_hive( }; let is_cloud = is_cloud_url(first_path); - let mut out_paths = vec![]; + + /// Wrapper around `Vec` that also tracks file extensions, so that + /// we don't have to traverse the entire list again to validate extensions. + struct OutPaths { + paths: Vec, + exts: [Option<(PlSmallStr, usize)>; 2], + current_idx: usize, + } + + impl OutPaths { + fn update_ext_status( + current_idx: &mut usize, + exts: &mut [Option<(PlSmallStr, usize)>; 2], + value: &Path, + ) { + let ext = value + .extension() + .map(|x| PlSmallStr::from(x.to_str().unwrap())) + .unwrap_or(PlSmallStr::EMPTY); + + if exts[0].is_none() { + exts[0] = Some((ext, *current_idx)); + } else if exts[1].is_none() && ext != exts[0].as_ref().unwrap().0 { + exts[1] = Some((ext, *current_idx)); + } + + *current_idx += 1; + } + + fn push(&mut self, value: PathBuf) { + { + let current_idx = &mut self.current_idx; + let exts = &mut self.exts; + Self::update_ext_status(current_idx, exts, &value); + } + self.paths.push(value) + } + + fn extend(&mut self, values: impl IntoIterator) { + let current_idx = &mut self.current_idx; + let exts = &mut self.exts; + + self.paths.extend(values.into_iter().inspect(|x| { + Self::update_ext_status(current_idx, exts, x); + })) + } + + fn extend_from_slice(&mut self, values: &[PathBuf]) { + self.extend(values.iter().cloned()) + } + } + + let mut out_paths = OutPaths { + paths: vec![], + exts: [None, None], + current_idx: 0, + }; let mut hive_idx_tracker = HiveIdxTracker { idx: usize::MAX, @@ -337,31 +394,20 @@ pub fn expand_paths_hive( } } - let out_paths = if expanded_from_single_directory(paths, out_paths.as_ref()) { - // Require all file extensions to be the same when expanding a single directory. - let ext = out_paths[0].extension(); - - (0..out_paths.len()) - .map(|i| { - let path = out_paths[i].clone(); - - if path.extension() != ext { - polars_bail!( - InvalidOperation: r#"directory contained paths with different file extensions: \ - first path: {}, second path: {}. Please use a glob pattern to explicitly specify \ - which files to read (e.g. "dir/**/*", "dir/**/*.parquet")"#, - out_paths[i - 1].to_str().unwrap(), path.to_str().unwrap() - ); - }; + assert_eq!(out_paths.current_idx, out_paths.paths.len()); - Ok(path) - }) - .collect::>>()? - } else { - out_paths - }; + if expanded_from_single_directory(paths, out_paths.paths.as_slice()) { + if let [Some((_, i1)), Some((_, i2))] = out_paths.exts { + polars_bail!( + InvalidOperation: r#"directory contained paths with different file extensions: \ + first path: {}, second path: {}. Please use a glob pattern to explicitly specify \ + which files to read (e.g. "dir/**/*", "dir/**/*.parquet")"#, + &out_paths.paths[i1].to_string_lossy(), &out_paths.paths[i2].to_string_lossy() + ) + } + } - Ok((Arc::new(out_paths), hive_idx_tracker.idx)) + Ok((Arc::new(out_paths.paths), hive_idx_tracker.idx)) } /// Ignores errors from `std::fs::create_dir_all` if the directory exists. diff --git a/py-polars/tests/unit/io/test_scan.py b/py-polars/tests/unit/io/test_scan.py index 1bcc463bd2e7f..fa10da619a5a8 100644 --- a/py-polars/tests/unit/io/test_scan.py +++ b/py-polars/tests/unit/io/test_scan.py @@ -575,6 +575,17 @@ def test_path_expansion_excludes_empty_files_17362(tmp_path: Path) -> None: assert_frame_equal(pl.scan_parquet(tmp_path / "*").collect(), df) +@pytest.mark.write_disk +def test_path_expansion_empty_directory_does_not_panic(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + with pytest.raises(pl.exceptions.ComputeError): + pl.scan_parquet(tmp_path).collect() + + with pytest.raises(pl.exceptions.ComputeError): + pl.scan_parquet(tmp_path / "**/*").collect() + + @pytest.mark.write_disk def test_scan_single_dir_differing_file_extensions_raises_17436(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True)