Skip to content

Commit

Permalink
c
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Sep 9, 2024
1 parent 6076421 commit 6be5676
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 24 deletions.
94 changes: 70 additions & 24 deletions crates/polars-io/src/path_utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::sync::Arc;
use once_cell::sync::Lazy;
use polars_core::config;
use polars_core::error::{polars_bail, to_compute_err, PolarsError, PolarsResult};
use polars_utils::pl_str::PlSmallStr;
use regex::Regex;

#[cfg(feature = "cloud")]
Expand Down Expand Up @@ -135,7 +136,63 @@ pub fn expand_paths_hive(
};

let is_cloud = is_cloud_url(first_path);
let mut out_paths = vec![];

/// Wrapper around `Vec<PathBuf>` that also tracks file extensions, so that
/// we don't have to traverse the entire list again to validate extensions.
struct OutPaths {
paths: Vec<PathBuf>,
exts: [Option<(PlSmallStr, usize)>; 2],
current_idx: usize,
}

impl OutPaths {
fn update_ext_status(
current_idx: &mut usize,
exts: &mut [Option<(PlSmallStr, usize)>; 2],
value: &Path,
) {
let ext = value
.extension()
.map(|x| PlSmallStr::from(x.to_str().unwrap()))
.unwrap_or(PlSmallStr::EMPTY);

if exts[0].is_none() {
exts[0] = Some((ext, *current_idx));
} else if exts[1].is_none() && ext != exts[0].as_ref().unwrap().0 {
exts[1] = Some((ext, *current_idx));
}

*current_idx += 1;
}

fn push(&mut self, value: PathBuf) {
{
let current_idx = &mut self.current_idx;
let exts = &mut self.exts;
Self::update_ext_status(current_idx, exts, &value);
}
self.paths.push(value)
}

fn extend(&mut self, values: impl IntoIterator<Item = PathBuf>) {
let current_idx = &mut self.current_idx;
let exts = &mut self.exts;

self.paths.extend(values.into_iter().inspect(|x| {
Self::update_ext_status(current_idx, exts, x);
}))
}

fn extend_from_slice(&mut self, values: &[PathBuf]) {
self.extend(values.iter().cloned())
}
}

let mut out_paths = OutPaths {
paths: vec![],
exts: [None, None],
current_idx: 0,
};

let mut hive_idx_tracker = HiveIdxTracker {
idx: usize::MAX,
Expand Down Expand Up @@ -337,31 +394,20 @@ pub fn expand_paths_hive(
}
}

let out_paths = if expanded_from_single_directory(paths, out_paths.as_ref()) {
// Require all file extensions to be the same when expanding a single directory.
let ext = out_paths[0].extension();

(0..out_paths.len())
.map(|i| {
let path = out_paths[i].clone();

if path.extension() != ext {
polars_bail!(
InvalidOperation: r#"directory contained paths with different file extensions: \
first path: {}, second path: {}. Please use a glob pattern to explicitly specify \
which files to read (e.g. "dir/**/*", "dir/**/*.parquet")"#,
out_paths[i - 1].to_str().unwrap(), path.to_str().unwrap()
);
};
assert_eq!(out_paths.current_idx, out_paths.paths.len());

Ok(path)
})
.collect::<PolarsResult<Vec<_>>>()?
} else {
out_paths
};
if expanded_from_single_directory(paths, out_paths.paths.as_slice()) {
if let [Some((_, i1)), Some((_, i2))] = out_paths.exts {
polars_bail!(
InvalidOperation: r#"directory contained paths with different file extensions: \
first path: {}, second path: {}. Please use a glob pattern to explicitly specify \
which files to read (e.g. "dir/**/*", "dir/**/*.parquet")"#,
&out_paths.paths[i1].to_string_lossy(), &out_paths.paths[i2].to_string_lossy()
)
}
}

Ok((Arc::new(out_paths), hive_idx_tracker.idx))
Ok((Arc::new(out_paths.paths), hive_idx_tracker.idx))
}

/// Ignores errors from `std::fs::create_dir_all` if the directory exists.
Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/unit/io/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,17 @@ def test_path_expansion_excludes_empty_files_17362(tmp_path: Path) -> None:
assert_frame_equal(pl.scan_parquet(tmp_path / "*").collect(), df)


@pytest.mark.write_disk
def test_path_expansion_empty_directory_does_not_panic(tmp_path: Path) -> None:
tmp_path.mkdir(exist_ok=True)

with pytest.raises(pl.exceptions.ComputeError):
pl.scan_parquet(tmp_path).collect()

with pytest.raises(pl.exceptions.ComputeError):
pl.scan_parquet(tmp_path / "**/*").collect()


@pytest.mark.write_disk
def test_scan_single_dir_differing_file_extensions_raises_17436(tmp_path: Path) -> None:
tmp_path.mkdir(exist_ok=True)
Expand Down

0 comments on commit 6be5676

Please sign in to comment.