diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 000000000000..1f678a632560 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,3 @@ +[env] +# Tune jemalloc (https://github.com/pola-rs/polars/issues/18088). +JEMALLOC_SYS_WITH_MALLOC_CONF = "dirty_decay_ms:500,muzzy_decay_ms:-1" diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 335d0a32b754..ce06b799f909 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -48,7 +48,12 @@ jobs: - name: Install Python dependencies working-directory: py-polars - run: uv pip install --compile-bytecode -r requirements-dev.txt -r requirements-ci.txt --verbose + run: | + # Install typing-extensions separately whilst the `--extra-index-url` in `requirements-ci.txt` + # doesn't have an up-to-date typing-extensions, see + # https://github.com/astral-sh/uv/issues/6028#issuecomment-2287232150 + uv pip install -U typing-extensions + uv pip install --compile-bytecode -r requirements-dev.txt -r requirements-ci.txt --verbose - name: Set up Rust run: rustup show diff --git a/.github/workflows/lint-global.yml b/.github/workflows/lint-global.yml index ed1cf6c3419c..8a599fe382a9 100644 --- a/.github/workflows/lint-global.yml +++ b/.github/workflows/lint-global.yml @@ -15,4 +15,4 @@ jobs: - name: Lint Markdown and TOML uses: dprint/check@v2.2 - name: Spell Check with Typos - uses: crate-ci/typos@v1.23.5 + uses: crate-ci/typos@v1.24.2 diff --git a/.github/workflows/test-coverage.yml b/.github/workflows/test-coverage.yml index ce642bbec306..c774bdf864c9 100644 --- a/.github/workflows/test-coverage.yml +++ b/.github/workflows/test-coverage.yml @@ -103,7 +103,12 @@ jobs: - name: Install Python dependencies working-directory: py-polars - run: uv pip install --compile-bytecode -r requirements-dev.txt -r requirements-ci.txt --verbose + run: | + # Install typing-extensions separately whilst the `--extra-index-url` in `requirements-ci.txt` + # doesn't have an up-to-date typing-extensions, see + # https://github.com/astral-sh/uv/issues/6028#issuecomment-2287232150 + uv pip install -U typing-extensions + uv pip install --compile-bytecode -r requirements-dev.txt -r requirements-ci.txt --verbose - name: Set up Rust run: rustup component add llvm-tools-preview diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index ca717ef191f2..089ccb9f553a 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -66,6 +66,10 @@ jobs: - name: Install Python dependencies run: | pip install uv + # Install typing-extensions separately whilst the `--extra-index-url` in `requirements-ci.txt` + # doesn't have an up-to-date typing-extensions, see + # https://github.com/astral-sh/uv/issues/6028#issuecomment-2287232150 + uv pip install -U typing-extensions uv pip install --compile-bytecode -r requirements-dev.txt -r requirements-ci.txt --verbose - name: Set up Rust diff --git a/.gitignore b/.gitignore index 4e34b437a91f..6cb7cc6fa0bd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ *.iml *.so +*.pyd *.ipynb .ENV .env @@ -29,7 +30,6 @@ __pycache__/ .coverage # Rust -.cargo/ target/ # Data diff --git a/Cargo.lock b/Cargo.lock index 4d85a199a8f7..2ff76ee62212 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "adler2" +version = "2.0.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + [[package]] name = "adler32" version = "1.2.0" @@ -102,13 +108,13 @@ checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" [[package]] name = "apache-avro" -version = "0.16.0" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ceb7c683b2f8f40970b70e39ff8be514c95b96fcb9c4af87e1ed2cb2e10801a0" +checksum = "1aef82843a0ec9f8b19567445ad2421ceeb1d711514384bdd3d49fe37102ee13" dependencies = [ + "bigdecimal", "crc32fast", "digest", - "lazy_static", "libflate 2.1.0", "log", "num-bigint", @@ -116,10 +122,11 @@ dependencies = [ "rand", "regex-lite", "serde", + "serde_bytes", "serde_json", "snap", - "strum 0.25.0", - "strum_macros 0.25.3", + "strum", + "strum_macros", "thiserror", "typed-builder", "uuid", @@ -163,15 +170,15 @@ checksum = "9d151e35f61089500b617991b791fc8bfd237ae50cd5950803758a179b41e67a" [[package]] name = "arrayvec" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow-array" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81c16ec702d3898c2f5cfdc148443c6cd7dbe5bac28399859eb0a3d38f072827" +checksum = "16f4a9468c882dc66862cef4e1fd8423d47e67972377d85d80e022786427768c" dependencies = [ "ahash", "arrow-buffer", @@ -196,9 +203,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a769666ffac256dd301006faca1ca553d0ae7cffcf4cd07095f73f95eb226514" +checksum = "dd9d6f18c65ef7a2573ab498c374d8ae364b4a4edf67105357491c031f716ca5" dependencies = [ "arrow-buffer", "arrow-schema", @@ -208,9 +215,9 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "52.1.0" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dab1c12b40e29d9f3b699e0203c2a73ba558444c05e388a4377208f8f9c97eee" +checksum = "9e972cd1ff4a4ccd22f86d3e53e835c2ed92e0eea6a3e8eadb72b4f1ac802cf8" [[package]] name = "arrow2" @@ -251,7 +258,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -262,7 +269,7 @@ checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -310,9 +317,9 @@ dependencies = [ [[package]] name = "aws-config" -version = "1.5.4" +version = "1.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "caf6cfe2881cb1fcbba9ae946fb9a6480d3b7a714ca84c74925014a89ef3387a" +checksum = "4e95816a168520d72c0e7680c405a5a8c1fb6a035b4bc4b9d7b0de8e1a941697" dependencies = [ "aws-credential-types", "aws-runtime", @@ -330,7 +337,6 @@ dependencies = [ "fastrand", "hex", "http 0.2.12", - "hyper 0.14.30", "ring", "time", "tokio", @@ -353,9 +359,9 @@ dependencies = [ [[package]] name = "aws-runtime" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87c5f920ffd1e0526ec9e70e50bf444db50b204395a0fa7016bbf9e31ea1698f" +checksum = 
"f42c2d4218de4dcd890a109461e2f799a1a2ba3bcd2cde9af88360f5df9266c6" dependencies = [ "aws-credential-types", "aws-sigv4", @@ -369,6 +375,7 @@ dependencies = [ "fastrand", "http 0.2.12", "http-body 0.4.6", + "once_cell", "percent-encoding", "pin-project-lite", "tracing", @@ -412,9 +419,9 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.36.0" +version = "1.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6acca681c53374bf1d9af0e317a41d12a44902ca0f2d1e10e5cb5bb98ed74f35" +checksum = "11822090cf501c316c6f75711d77b96fba30658e3867a7762e5e2f5d32d31e81" dependencies = [ "aws-credential-types", "aws-runtime", @@ -434,9 +441,9 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.37.0" +version = "1.40.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b79c6bdfe612503a526059c05c9ccccbf6bd9530b003673cb863e547fd7c0c9a" +checksum = "78a2a06ff89176123945d1bbe865603c4d7101bea216a550bb4d2e4e9ba74d74" dependencies = [ "aws-credential-types", "aws-runtime", @@ -456,9 +463,9 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.36.0" +version = "1.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32e6ecdb2bd756f3b2383e6f0588dc10a4e65f5d551e70a56e0bfe0c884673ce" +checksum = "a20a91795850826a6f456f4a48eff1dfa59a0e69bdbf5b8c50518fd372106574" dependencies = [ "aws-credential-types", "aws-runtime", @@ -591,9 +598,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.6.2" +version = "1.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce87155eba55e11768b8c1afa607f3e864ae82f03caf63258b37455b0ad02537" +checksum = "0abbf454960d0db2ad12684a1640120e7557294b0ff8e2f11236290a1b293225" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -618,9 +625,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.7.1" +version = "1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30819352ed0a04ecf6a2f3477e344d2d1ba33d43e0f09ad9047c12e0d923616f" +checksum = "e086682a53d3aa241192aa110fa8dfce98f2f5ac2ead0de84d41582c7e8fdb96" dependencies = [ "aws-smithy-async", "aws-smithy-types", @@ -635,9 +642,9 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.2.0" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfe321a6b21f5d8eabd0ade9c55d3d0335f3c3157fc2b3e87f05f34b539e4df5" +checksum = "6cee7cadb433c781d3299b916fbf620fea813bf38f49db282fb6858141a05cc8" dependencies = [ "base64-simd", "bytes", @@ -692,7 +699,7 @@ dependencies = [ "cc", "cfg-if", "libc", - "miniz_oxide", + "miniz_oxide 0.7.4", "object", "rustc-demangle", ] @@ -731,6 +738,20 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" +[[package]] +name = "bigdecimal" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d712318a27c7150326677b321a5fa91b55f6d9034ffd67f20319e147d40cee" +dependencies = [ + "autocfg", + "libm", + "num-bigint", + "num-integer", + "num-traits", + "serde", +] + [[package]] name = "bincode" version = "1.3.3" @@ -751,9 +772,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.5.3" +version = "1.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9ec96fe9a81b5e365f9db71fe00edc4fe4ca2cc7dcb7861f0603012a7caa210" +checksum = 
"d82033247fd8e890df8f740e407ad4d038debb9eb1f40533fffb32e7d17dc6f7" dependencies = [ "arrayref", "arrayvec", @@ -801,17 +822,6 @@ dependencies = [ "alloc-stdlib", ] -[[package]] -name = "bstr" -version = "1.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40723b8fb387abc38f4f4a37c09073622e41dd12327033091ef8950659e6dc0c" -dependencies = [ - "memchr", - "regex-automata", - "serde", -] - [[package]] name = "built" version = "0.7.4" @@ -831,29 +841,35 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.16.1" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b236fc92302c97ed75b38da1f4917b5cdda4984745740f153a5d3059e48d725e" +checksum = "6fd4c6dcc3b0aea2f5c0b4b82c2b15fe39ddbc76041a310848f4706edf76bb31" dependencies = [ "bytemuck_derive", ] [[package]] name = "bytemuck_derive" -version = "1.7.0" +version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ee891b04274a59bd38b412188e24b849617b2e45a0fd8d057deb63e7403761b" +checksum = "0cc8b54b395f2fcfbb3d90c47b01c7f444d94d05bdeb775811dec868ac3bbc26" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" -version = "1.7.0" +version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fca2be1d5c43812bae364ee3f30b3afcb7877cf59f4aeb94c66f313a41d2fac9" +checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50" [[package]] name = "bytes-utils" @@ -892,14 +908,24 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" +[[package]] +name = "castaway" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0abae9be0aaf9ea96a3b1b8b1b55c602ca751eba1b1500220cea4ecbafe7c0d5" +dependencies = [ + "rustversion", +] + [[package]] name = "cc" -version = "1.1.6" +version = "1.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2aba8f4e9906c7ce3c73463f62a7f0c65183ada1a2d47e397cc8810827f9694f" +checksum = "50d2eb3cd3d1bf4529e31c215ee6f93ec5a3d536d9f578f93d9d33ee19562932" dependencies = [ "jobserver", "libc", + "shlex", ] [[package]] @@ -972,18 +998,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.11" +version = "4.5.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35723e6a11662c2afb578bcf0b88bf6ea8e21282a953428f240574fcc3a2b5b3" +checksum = "ed6719fffa43d0d87e5fd8caeab59be1554fb028cd30edc88fc4369b17971019" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.11" +version = "4.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49eb96cbfa7cfa35017b7cd548c75b14c3118c98b423041d70562665e07fb0fa" +checksum = "216aec2b177652e3846684cbfe25c9964d18ec45234f0f5da5157b207ed1aab6" dependencies = [ "anstyle", "clap_lex", @@ -1006,9 +1032,9 @@ dependencies = [ [[package]] name = "cmake" -version = "0.1.50" +version = "0.1.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a31c789563b815f77f4250caee12365734369f942439b7defd71e18a48197130" +checksum = 
"fb1e43aa7fd152b1f968787f7dbcdeb306d1867ff373c69955211876c053f91a" dependencies = [ "cc", ] @@ -1020,11 +1046,26 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b34115915337defe99b2aff5c2ce6771e5fbc4079f4b506301f5cf394c8452f7" dependencies = [ "crossterm", - "strum 0.26.3", - "strum_macros 0.26.4", + "strum", + "strum_macros", "unicode-width", ] +[[package]] +name = "compact_str" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6050c3a16ddab2e412160b31f2c871015704239bca62f72f6e5f0be631d3f644" +dependencies = [ + "castaway", + "cfg-if", + "itoa", + "rustversion", + "ryu", + "serde", + "static_assertions", +] + [[package]] name = "const-oid" version = "0.9.6" @@ -1069,9 +1110,9 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.6" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "core2" @@ -1084,9 +1125,9 @@ dependencies = [ [[package]] name = "cpufeatures" -version = "0.2.12" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +checksum = "51e852e6dc9a5bed1fae92dd2375037bf2b768725bf3be87811edee3249d09ad" dependencies = [ "libc", ] @@ -1361,7 +1402,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -1416,9 +1457,9 @@ checksum = "95765f67b4b18863968b4a1bd5bb576f732b29a4a28c7cd84c09fa3e2875f33c" [[package]] name = "fastrand" -version = "2.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" +checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" [[package]] name = "ff" @@ -1432,13 +1473,13 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.30" +version = "1.0.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" +checksum = "324a1be68054ef05ad64b861cc9eaf1d623d2d8cb25b4bf2cb9cdd902b4bf253" dependencies = [ "crc32fast", "libz-ng-sys", - "miniz_oxide", + "miniz_oxide 0.8.0", ] [[package]] @@ -1537,7 +1578,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -1660,9 +1701,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa82e28a107a8cc405f0839610bdc9b15f1e25ec7d696aa5cf173edbcb1486ab" +checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205" dependencies = [ "atomic-waker", "bytes", @@ -1734,6 +1775,12 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" +[[package]] +name = "hermit-abi" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" + [[package]] name = "hex" version = "0.4.3" @@ -1865,7 +1912,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "h2 0.4.5", + "h2 0.4.6", "http 1.1.0", 
"http-body 1.0.1", "httparse", @@ -1903,7 +1950,7 @@ dependencies = [ "hyper 1.4.1", "hyper-util", "rustls 0.23.12", - "rustls-native-certs 0.7.1", + "rustls-native-certs 0.7.2", "rustls-pki-types", "tokio", "tokio-rustls 0.26.0", @@ -1912,9 +1959,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ab92f4f49ee4fb4f997c784b7a2e0fa70050211e0b6a287f898c3c9785ca956" +checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9" dependencies = [ "bytes", "futures-channel", @@ -1965,9 +2012,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.3.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de3fc2e30ba82dd1b3911c8de1ffc143c74a914a14e99514d7637e3099df5ea0" +checksum = "68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5" dependencies = [ "equivalent", "hashbrown", @@ -1994,11 +2041,11 @@ checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" [[package]] name = "is-terminal" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b" +checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b" dependencies = [ - "hermit-abi", + "hermit-abi 0.4.0", "libc", "windows-sys 0.52.0", ] @@ -2064,9 +2111,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.69" +version = "0.3.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" dependencies = [ "wasm-bindgen", ] @@ -2154,9 +2201,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.155" +version = "0.2.158" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" [[package]] name = "libflate" @@ -2221,7 +2268,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" dependencies = [ "cfg-if", - "windows-targets 0.52.6", + "windows-targets 0.48.5", ] [[package]] @@ -2242,9 +2289,9 @@ dependencies = [ [[package]] name = "libz-ng-sys" -version = "1.1.15" +version = "1.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6409efc61b12687963e602df8ecf70e8ddacf95bc6576bcf16e3ac6328083c5" +checksum = "4436751a01da56f1277f323c80d584ffad94a3d14aecd959dd0dff75aa73a438" dependencies = [ "cmake", "libc", @@ -2252,9 +2299,9 @@ dependencies = [ [[package]] name = "libz-sys" -version = "1.1.18" +version = "1.1.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c15da26e5af7e25c90b37a2d75cdbf940cf4a55316de9d84c679c9b8bfabf82e" +checksum = "d2d16453e800a8cf6dd2fc3eb4bc99b786a9b90c663b8559a5b1a041bf89e472" dependencies = [ "cc", "libc", @@ -2286,9 +2333,9 @@ checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "lru" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3262e75e648fce39813cb56ac41f3c3e3f65217ebf3844d818d1f9398cfb0dc" +checksum = 
"37ee39891760e7d94734f6f63fedc29a2e4a152f836120753a72503f09fcf904" dependencies = [ "hashbrown", ] @@ -2324,9 +2371,9 @@ dependencies = [ [[package]] name = "matrixmultiply" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" dependencies = [ "autocfg", "rawpointer", @@ -2390,13 +2437,22 @@ dependencies = [ "adler", ] +[[package]] +name = "miniz_oxide" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +dependencies = [ + "adler2", +] + [[package]] name = "mio" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" +checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" dependencies = [ - "hermit-abi", + "hermit-abi 0.3.9", "libc", "wasi", "windows-sys 0.52.0", @@ -2477,6 +2533,7 @@ checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ "num-integer", "num-traits", + "serde", ] [[package]] @@ -2546,7 +2603,7 @@ dependencies = [ "num-integer", "num-traits", "pyo3", - "rustc-hash", + "rustc-hash 1.1.0", ] [[package]] @@ -2650,9 +2707,9 @@ dependencies = [ [[package]] name = "object" -version = "0.36.2" +version = "0.36.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f203fa8daa7bb185f760ae12bd8e097f63d17041dcdcaf675ac54cdf863170e" +checksum = "27b64972346851a39438c60b341ebc01bba47464ae329e55cf343eb93964efd9" dependencies = [ "memchr", ] @@ -2678,7 +2735,7 @@ dependencies = [ "rand", "reqwest", "ring", - "rustls-pemfile 2.1.2", + "rustls-pemfile 2.1.3", "serde", "serde_json", "snafu", @@ -2826,7 +2883,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -2896,7 +2953,7 @@ dependencies = [ [[package]] name = "polars" -version = "0.41.3" +version = "0.42.0" dependencies = [ "ahash", "apache-avro", @@ -2926,7 +2983,7 @@ dependencies = [ [[package]] name = "polars-arrow" -version = "0.41.3" +version = "0.42.0" dependencies = [ "ahash", "arrow-array", @@ -2961,6 +3018,7 @@ dependencies = [ "parking_lot", "polars-arrow-format", "polars-error", + "polars-schema", "polars-utils", "proptest", "rand", @@ -2994,7 +3052,7 @@ dependencies = [ [[package]] name = "polars-compute" -version = "0.41.3" +version = "0.42.0" dependencies = [ "bytemuck", "either", @@ -3009,7 +3067,7 @@ dependencies = [ [[package]] name = "polars-core" -version = "0.41.3" +version = "0.42.0" dependencies = [ "ahash", "arrow-array", @@ -3029,6 +3087,7 @@ dependencies = [ "polars-compute", "polars-error", "polars-row", + "polars-schema", "polars-utils", "rand", "rand_distr", @@ -3037,7 +3096,6 @@ dependencies = [ "scopeguard", "serde", "serde_json", - "smartstring", "thiserror", "version_check", "xxhash-rust", @@ -3045,7 +3103,7 @@ dependencies = [ [[package]] name = "polars-doc-examples" -version = "0.41.3" +version = "0.42.0" dependencies = [ "aws-config", "aws-sdk-s3", @@ -3059,7 +3117,7 @@ dependencies = [ [[package]] name = "polars-error" -version = "0.41.3" +version = "0.42.0" dependencies = [ "avro-schema", "object_store", @@ -3071,12 +3129,13 @@ dependencies = [ [[package]] 
name = "polars-expr" -version = "0.41.3" +version = "0.42.0" dependencies = [ "ahash", "bitflags", "once_cell", "polars-arrow", + "polars-compute", "polars-core", "polars-io", "polars-json", @@ -3085,12 +3144,11 @@ dependencies = [ "polars-time", "polars-utils", "rayon", - "smartstring", ] [[package]] name = "polars-ffi" -version = "0.41.3" +version = "0.42.0" dependencies = [ "polars-arrow", "polars-core", @@ -3098,7 +3156,7 @@ dependencies = [ [[package]] name = "polars-io" -version = "0.41.3" +version = "0.42.0" dependencies = [ "ahash", "async-trait", @@ -3112,6 +3170,7 @@ dependencies = [ "fs4", "futures", "glob", + "hashbrown", "home", "itoa", "memchr", @@ -3125,6 +3184,7 @@ dependencies = [ "polars-error", "polars-json", "polars-parquet", + "polars-schema", "polars-time", "polars-utils", "rayon", @@ -3135,7 +3195,6 @@ dependencies = [ "serde_json", "simd-json", "simdutf8", - "smartstring", "tempfile", "tokio", "tokio-util", @@ -3145,7 +3204,7 @@ dependencies = [ [[package]] name = "polars-json" -version = "0.41.3" +version = "0.42.0" dependencies = [ "ahash", "chrono", @@ -3165,7 +3224,7 @@ dependencies = [ [[package]] name = "polars-lazy" -version = "0.41.3" +version = "0.42.0" dependencies = [ "ahash", "bitflags", @@ -3187,14 +3246,13 @@ dependencies = [ "pyo3", "rayon", "serde_json", - "smartstring", "tokio", "version_check", ] [[package]] name = "polars-mem-engine" -version = "0.41.3" +version = "0.42.0" dependencies = [ "futures", "memmap2", @@ -3215,7 +3273,7 @@ dependencies = [ [[package]] name = "polars-ops" -version = "0.41.3" +version = "0.42.0" dependencies = [ "ahash", "aho-corasick", @@ -3236,6 +3294,7 @@ dependencies = [ "polars-core", "polars-error", "polars-json", + "polars-schema", "polars-utils", "rand", "rand_distr", @@ -3243,14 +3302,13 @@ dependencies = [ "regex", "serde", "serde_json", - "smartstring", "unicode-reverse", "version_check", ] [[package]] name = "polars-parquet" -version = "0.41.3" +version = "0.42.0" dependencies = [ "ahash", "async-stream", @@ -3261,6 +3319,7 @@ dependencies = [ "fallible-streaming-iterator", "flate2", "futures", + "hashbrown", "lz4", "lz4_flex", "num-traits", @@ -3280,7 +3339,7 @@ dependencies = [ [[package]] name = "polars-pipe" -version = "0.41.3" +version = "0.42.0" dependencies = [ "crossbeam-channel", "crossbeam-queue", @@ -3298,7 +3357,6 @@ dependencies = [ "polars-row", "polars-utils", "rayon", - "smartstring", "tokio", "uuid", "version_check", @@ -3306,11 +3364,12 @@ dependencies = [ [[package]] name = "polars-plan" -version = "0.41.3" +version = "0.42.0" dependencies = [ "ahash", "bitflags", "bytemuck", + "bytes", "chrono", "chrono-tz", "ciborium", @@ -3336,14 +3395,47 @@ dependencies = [ "regex", "serde", "serde_json", - "smartstring", - "strum_macros 0.26.4", + "strum_macros", + "version_check", +] + +[[package]] +name = "polars-python" +version = "0.42.0" +dependencies = [ + "ahash", + "arboard", + "bytemuck", + "bytes", + "ciborium", + "either", + "itoa", + "libc", + "ndarray", + "num-traits", + "numpy", + "once_cell", + "polars", + "polars-core", + "polars-error", + "polars-io", + "polars-lazy", + "polars-ops", + "polars-parquet", + "polars-plan", + "polars-stream", + "polars-time", + "polars-utils", + "pyo3", + "recursive", + "serde_json", + "thiserror", "version_check", ] [[package]] name = "polars-row" -version = "0.41.3" +version = "0.42.0" dependencies = [ "bytemuck", "polars-arrow", @@ -3351,9 +3443,20 @@ dependencies = [ "polars-utils", ] +[[package]] +name = "polars-schema" +version = "0.42.0" +dependencies 
= [ + "indexmap", + "polars-error", + "polars-utils", + "serde", + "version_check", +] + [[package]] name = "polars-sql" -version = "0.41.3" +version = "0.42.0" dependencies = [ "hex", "once_cell", @@ -3364,6 +3467,7 @@ dependencies = [ "polars-ops", "polars-plan", "polars-time", + "polars-utils", "rand", "serde", "serde_json", @@ -3372,11 +3476,13 @@ dependencies = [ [[package]] name = "polars-stream" -version = "0.41.3" +version = "0.42.0" dependencies = [ "atomic-waker", "crossbeam-deque", "crossbeam-utils", + "futures", + "memmap2", "parking_lot", "pin-project-lite", "polars-core", @@ -3384,6 +3490,7 @@ dependencies = [ "polars-expr", "polars-io", "polars-mem-engine", + "polars-parquet", "polars-plan", "polars-utils", "rand", @@ -3396,7 +3503,7 @@ dependencies = [ [[package]] name = "polars-time" -version = "0.41.3" +version = "0.42.0" dependencies = [ "atoi", "bytemuck", @@ -3411,18 +3518,19 @@ dependencies = [ "polars-utils", "regex", "serde", - "smartstring", ] [[package]] name = "polars-utils" -version = "0.41.3" +version = "0.42.0" dependencies = [ "ahash", "bytemuck", "bytes", + "compact_str", "hashbrown", "indexmap", + "libc", "memmap2", "num-traits", "once_cell", @@ -3430,7 +3538,7 @@ dependencies = [ "rand", "raw-cpuid", "rayon", - "smartstring", + "serde", "stacker", "sysinfo", "version_check", @@ -3450,9 +3558,12 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" -version = "0.2.17" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] [[package]] name = "proc-macro2" @@ -3513,38 +3624,14 @@ dependencies = [ [[package]] name = "py-polars" -version = "1.4.1" +version = "1.6.0" dependencies = [ - "ahash", - "arboard", "built", - "bytemuck", - "ciborium", - "either", - "itoa", "jemallocator", "libc", "mimalloc", - "ndarray", - "num-traits", - "numpy", - "once_cell", - "polars", - "polars-core", - "polars-error", - "polars-io", - "polars-lazy", - "polars-ops", - "polars-parquet", - "polars-plan", - "polars-stream", - "polars-time", - "polars-utils", + "polars-python", "pyo3", - "recursive", - "serde_json", - "smartstring", - "thiserror", ] [[package]] @@ -3596,7 +3683,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -3609,7 +3696,7 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -3641,16 +3728,17 @@ dependencies = [ [[package]] name = "quinn" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4ceeeeabace7857413798eb1ffa1e9c905a9946a57d81fb69b4b71c4d8eb3ad" +checksum = "b22d8e7369034b9a7132bc2008cac12f2013c8132b45e0554e6e20e2617f2156" dependencies = [ "bytes", "pin-project-lite", "quinn-proto", "quinn-udp", - "rustc-hash", + "rustc-hash 2.0.0", "rustls 0.23.12", + "socket2", "thiserror", "tokio", "tracing", @@ -3658,14 +3746,14 @@ dependencies = [ [[package]] name = "quinn-proto" -version = "0.11.3" +version = "0.11.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddf517c03a109db8100448a4be38d498df8a210a99fe0e1b9eaf39e78c640efe" +checksum = "fadfaed2cd7f389d0161bb73eeb07b7b78f8691047a6f3e73caaeae55310a4a6" dependencies = [ "bytes", "rand", "ring", - "rustc-hash", + 
"rustc-hash 2.0.0", "rustls 0.23.12", "slab", "thiserror", @@ -3682,14 +3770,15 @@ dependencies = [ "libc", "once_cell", "socket2", + "tracing", "windows-sys 0.52.0", ] [[package]] name = "quote" -version = "1.0.36" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ "proc-macro2", ] @@ -3805,7 +3894,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" dependencies = [ "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -3834,14 +3923,14 @@ checksum = "bcc303e793d3734489387d205e9b186fac9c6cfacedd98cbb2e8a5943595f3e6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] name = "regex" -version = "1.10.5" +version = "1.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" +checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" dependencies = [ "aho-corasick", "memchr", @@ -3880,16 +3969,16 @@ checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" [[package]] name = "reqwest" -version = "0.12.5" +version = "0.12.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7d6d2a27d57148378eb5e111173f4276ad26340ecc5c49a4a2152167a2d6a37" +checksum = "f8f4955649ef5c38cc7f9e8aa41761d48fb9677197daea9984dc54f56aad5e63" dependencies = [ "base64 0.22.1", "bytes", "futures-channel", "futures-core", "futures-util", - "h2 0.4.5", + "h2 0.4.6", "http 1.1.0", "http-body 1.0.1", "http-body-util", @@ -3905,8 +3994,8 @@ dependencies = [ "pin-project-lite", "quinn", "rustls 0.23.12", - "rustls-native-certs 0.7.1", - "rustls-pemfile 2.1.2", + "rustls-native-certs 0.7.2", + "rustls-pemfile 2.1.3", "rustls-pki-types", "serde", "serde_json", @@ -3921,7 +4010,7 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", - "winreg", + "windows-registry", ] [[package]] @@ -3968,6 +4057,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +[[package]] +name = "rustc-hash" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152" + [[package]] name = "rustc_version" version = "0.4.0" @@ -4030,12 +4125,12 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a88d6d420651b496bdd98684116959239430022a115c1240e6c3993be0b15fba" +checksum = "04182dffc9091a404e0fc069ea5cd60e5b866c3adf881eff99a32d048242dffa" dependencies = [ "openssl-probe", - "rustls-pemfile 2.1.2", + "rustls-pemfile 2.1.3", "rustls-pki-types", "schannel", "security-framework", @@ -4052,9 +4147,9 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "2.1.2" +version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" +checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425" dependencies = [ "base64 0.22.1", "rustls-pki-types", @@ -4062,9 +4157,9 @@ dependencies = [ [[package]] name 
= "rustls-pki-types" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" +checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" [[package]] name = "rustls-webpki" @@ -4227,32 +4322,42 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.204" +version = "1.0.209" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12" +checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09" dependencies = [ "serde_derive", ] +[[package]] +name = "serde_bytes" +version = "0.11.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "387cc504cb06bb40a96c8e04e951fe01854cf6bc921053c954e4a606d9675c6a" +dependencies = [ + "serde", +] + [[package]] name = "serde_derive" -version = "1.0.204" +version = "1.0.209" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" +checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] name = "serde_json" -version = "1.0.120" +version = "1.0.127" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" +checksum = "8043c06d9f82bd7271361ed64f415fe5e12a77fdb52e573e7f06a516dea329ad" dependencies = [ "indexmap", "itoa", + "memchr", "ryu", "serde", ] @@ -4300,6 +4405,12 @@ dependencies = [ "digest", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "signal-hook-registry" version = "1.4.2" @@ -4373,18 +4484,6 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" -[[package]] -name = "smartstring" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fb72c633efbaa2dd666986505016c32c3044395ceaf881518399d2f4127ee29" -dependencies = [ - "autocfg", - "serde", - "static_assertions", - "version_check", -] - [[package]] name = "snafu" version = "0.7.5" @@ -4450,15 +4549,15 @@ dependencies = [ [[package]] name = "stacker" -version = "0.1.15" +version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c886bd4480155fd3ef527d45e9ac8dd7118a898a46530b7b94c3e21866259fce" +checksum = "95a5daa25ea337c85ed954c0496e3bdd2c7308cc3b24cf7b50d04876654c579f" dependencies = [ "cc", "cfg-if", "libc", "psm", - "winapi", + "windows-sys 0.36.1", ] [[package]] @@ -4488,31 +4587,12 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" -[[package]] -name = "strum" -version = "0.25.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" - [[package]] name = "strum" version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" -[[package]] -name = "strum_macros" -version = "0.25.3" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" -dependencies = [ - "heck 0.4.1", - "proc-macro2", - "quote", - "rustversion", - "syn 2.0.72", -] - [[package]] name = "strum_macros" version = "0.26.4" @@ -4523,7 +4603,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -4545,9 +4625,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.72" +version = "2.0.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc4b9b9bf2add8093d3f2c0204471e951b2285580335de42f9d2534f3ae7a8af" +checksum = "578e081a14e0cefc3279b0472138c513f37b41a08d5a3cca9b6e4e8ceb6cd525" dependencies = [ "proc-macro2", "quote", @@ -4559,14 +4639,16 @@ name = "sync_wrapper" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +dependencies = [ + "futures-core", +] [[package]] name = "sysinfo" -version = "0.31.0" +version = "0.31.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29a6b037e3af4ae9a9d6214198e4df53091363b2c96c88fc416a6c1bd92a2799" +checksum = "2b92e0bdf838cbc1c4c9ba14f9c97a7ec6cdcd1ae66b10e1e42775a25553f45d" dependencies = [ - "bstr", "core-foundation-sys", "libc", "memchr", @@ -4582,20 +4664,21 @@ checksum = "c1bbb9f3c5c463a01705937a24fdabc5047929ac764b2d5b9cf681c1f5041ed5" [[package]] name = "target-lexicon" -version = "0.12.15" +version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4873307b7c257eddcb50c9bedf158eb669578359fb28428bef438fec8e6ba7c2" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "tempfile" -version = "3.10.1" +version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64" dependencies = [ "cfg-if", "fastrand", + "once_cell", "rustix", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -4615,7 +4698,7 @@ checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -4684,9 +4767,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.39.1" +version = "1.39.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d040ac2b29ab03b09d4129c2f5bbd012a3ac2f79d38ff506a4bf8dd34b0eac8a" +checksum = "9babc99b9923bfa4804bd74722ff02c0381021eafa4db9949217e3be8e84fff5" dependencies = [ "backtrace", "bytes", @@ -4707,7 +4790,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -4759,9 +4842,9 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.7" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8fb9f64314842840f1d940ac544da178732128f1c78c21772e876579e0da1db" +checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" dependencies = [ "serde", ] @@ -4796,15 +4879,15 @@ dependencies = [ [[package]] name = "tower-layer" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" [[package]] name = "tower-service" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" @@ -4825,7 +4908,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -4855,22 +4938,22 @@ dependencies = [ [[package]] name = "typed-builder" -version = "0.16.2" +version = "0.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34085c17941e36627a879208083e25d357243812c30e7d7387c3b954f30ade16" +checksum = "a06fbd5b8de54c5f7c91f6fe4cebb949be2125d7758e630bb58b1d831dbce600" dependencies = [ "typed-builder-macro", ] [[package]] name = "typed-builder-macro" -version = "0.16.2" +version = "0.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" +checksum = "f9534daa9fd3ed0bd911d462a37f172228077e7abf18c18a5f67199d959205f8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -5023,34 +5106,35 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" dependencies = [ "cfg-if", + "once_cell", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.42" +version = "0.4.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76bc14366121efc8dbb487ab05bcc9d346b3b5ec0eaa76e46594cabbe51762c0" +checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" dependencies = [ "cfg-if", "js-sys", @@ -5060,9 +5144,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -5070,22 +5154,22 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", "wasm-bindgen-backend", "wasm-bindgen-shared", ] 
[[package]] name = "wasm-bindgen-shared" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" [[package]] name = "wasm-streams" @@ -5102,9 +5186,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.69" +version = "0.3.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" dependencies = [ "js-sys", "wasm-bindgen", @@ -5128,11 +5212,11 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] name = "winapi-util" -version = "0.1.8" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -5168,7 +5252,7 @@ checksum = "d2ed2439a290666cd67ecce2b0ffaad89c2a56b976b736e6ece670297897832d" dependencies = [ "windows-implement", "windows-interface", - "windows-result", + "windows-result 0.1.2", "windows-targets 0.52.6", ] @@ -5180,7 +5264,7 @@ checksum = "9107ddc059d5b6fbfbffdfa7a7fe3e22a226def0b2608f72e9d552763d3e1ad7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -5191,7 +5275,18 @@ checksum = "29bee4b38ea3cde66011baa44dba677c432a78593e202392d1e9070cf2a7fca7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", +] + +[[package]] +name = "windows-registry" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" +dependencies = [ + "windows-result 0.2.0", + "windows-strings", + "windows-targets 0.52.6", ] [[package]] @@ -5203,13 +5298,36 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-result" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-strings" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +dependencies = [ + "windows-result 0.2.0", + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" -version = "0.48.0" +version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +checksum = "ea04155a16a59f9eab786fe12a4a450e75cdb175f9e0d80da1e17db09f55b8d2" dependencies = [ - "windows-targets 0.48.5", + "windows_aarch64_msvc 0.36.1", + "windows_i686_gnu 0.36.1", + "windows_i686_msvc 0.36.1", + "windows_x86_64_gnu 0.36.1", + "windows_x86_64_msvc 0.36.1", ] [[package]] @@ -5221,6 +5339,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = 
"windows-targets" version = "0.48.5" @@ -5264,6 +5391,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" +[[package]] +name = "windows_aarch64_msvc" +version = "0.36.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9bb8c3fd39ade2d67e9874ac4f3db21f0d710bee00fe7cab16949ec184eeaa47" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -5276,6 +5409,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" +[[package]] +name = "windows_i686_gnu" +version = "0.36.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "180e6ccf01daf4c426b846dfc66db1fc518f074baa793aa7d9b9aaeffad6a3b6" + [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -5294,6 +5433,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" +[[package]] +name = "windows_i686_msvc" +version = "0.36.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2e7917148b2812d1eeafaeb22a97e4813dfa60a3f8f78ebe204bcc88f12f024" + [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -5306,6 +5451,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" +[[package]] +name = "windows_x86_64_gnu" +version = "0.36.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4dcd171b8776c41b97521e5da127a2d86ad280114807d0b2ab1e462bc764d9e1" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -5330,6 +5481,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" +[[package]] +name = "windows_x86_64_msvc" +version = "0.36.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" @@ -5351,16 +5508,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "winreg" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a277a57398d4bfa075df44f501a17cfdf8542d224f0d36095a2adc7aee4ef0a5" -dependencies = [ - "cfg-if", - "windows-sys 0.48.0", -] - [[package]] name = "x11rb" version = "0.13.1" @@ -5386,9 +5533,9 @@ checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" [[package]] name = "xxhash-rust" -version = "0.8.11" +version = "0.8.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63658493314859b4dfdf3fb8c1defd61587839def09582db50b8a4e93afca6bb" +checksum = "6a5cbf750400958819fb6178eaa83bee5cd9c29a26a40cc241df8c70fdd46984" [[package]] name = "zerocopy" @@ -5396,6 +5543,7 @@ version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ + "byteorder", "zerocopy-derive", ] @@ -5407,7 +5555,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -5427,18 +5575,18 @@ dependencies = [ 
[[package]] name = "zstd-safe" -version = "7.2.0" +version = "7.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa556e971e7b568dc775c136fc9de8c779b1c2fc3a63defaafadffdbd3181afa" +checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" dependencies = [ "zstd-sys", ] [[package]] name = "zstd-sys" -version = "2.0.12+zstd.1.5.6" +version = "2.0.13+zstd.1.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a4e40c320c3cb459d9a9ff6de98cff88f4751ee9275d140e2be94a2b74e4c13" +checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" dependencies = [ "cc", "pkg-config", diff --git a/Cargo.toml b/Cargo.toml index 4a6716efe131..7172176ce496 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ default-members = [ # ] [workspace.package] -version = "0.41.3" +version = "0.42.0" authors = ["Ritchie Vink "] edition = "2021" homepage = "https://www.pola.rs/" @@ -40,6 +40,7 @@ bytes = { version = "1.7" } chrono = { version = "0.4.31", default-features = false, features = ["std"] } chrono-tz = "0.8.1" ciborium = "0.2" +compact_str = { version = "0.8.0", features = ["serde"] } crossbeam-channel = "0.5.8" crossbeam-deque = "0.8.5" crossbeam-queue = "0.3" @@ -52,9 +53,10 @@ flate2 = { version = "1", default-features = false } futures = "0.3.25" hashbrown = { version = "0.14", features = ["rayon", "ahash", "serde"] } hex = "0.4.3" -indexmap = { version = "2", features = ["std"] } +indexmap = { version = "2", features = ["std", "serde"] } itoa = "1.0.6" itoap = { version = "1", features = ["simd"] } +libc = "0.2" memchr = "2.6" memmap = { package = "memmap2", version = "0.7" } multiversion = "0.7" @@ -75,12 +77,11 @@ regex = "1.9" reqwest = { version = "0.12", default-features = false } ryu = "1.0.13" scopeguard = "1.2.0" -serde = { version = "1.0.188", features = ["derive"] } +serde = { version = "1.0.188", features = ["derive", "rc"] } serde_json = "1" simd-json = { version = "0.13", features = ["known-key"] } simdutf8 = "0.1.4" slotmap = "1" -smartstring = "1" sqlparser = "0.49" stacker = "0.1" streaming-iterator = "0.1.9" @@ -96,25 +97,27 @@ version_check = "0.9.4" xxhash-rust = { version = "0.8.6", features = ["xxh3"] } zstd = "0.13" -polars = { version = "0.41.3", path = "crates/polars", default-features = false } -polars-compute = { version = "0.41.3", path = "crates/polars-compute", default-features = false } -polars-core = { version = "0.41.3", path = "crates/polars-core", default-features = false } -polars-error = { version = "0.41.3", path = "crates/polars-error", default-features = false } -polars-expr = { version = "0.41.3", path = "crates/polars-expr", default-features = false } -polars-ffi = { version = "0.41.3", path = "crates/polars-ffi", default-features = false } -polars-io = { version = "0.41.3", path = "crates/polars-io", default-features = false } -polars-json = { version = "0.41.3", path = "crates/polars-json", default-features = false } -polars-lazy = { version = "0.41.3", path = "crates/polars-lazy", default-features = false } -polars-mem-engine = { version = "0.41.3", path = "crates/polars-mem-engine", default-features = false } -polars-ops = { version = "0.41.3", path = "crates/polars-ops", default-features = false } -polars-parquet = { version = "0.41.3", path = "crates/polars-parquet", default-features = false } -polars-pipe = { version = "0.41.3", path = "crates/polars-pipe", default-features = false } -polars-plan = { version = "0.41.3", path = "crates/polars-plan", 
default-features = false } -polars-row = { version = "0.41.3", path = "crates/polars-row", default-features = false } -polars-sql = { version = "0.41.3", path = "crates/polars-sql", default-features = false } -polars-stream = { version = "0.41.3", path = "crates/polars-stream", default-features = false } -polars-time = { version = "0.41.3", path = "crates/polars-time", default-features = false } -polars-utils = { version = "0.41.3", path = "crates/polars-utils", default-features = false } +polars = { version = "0.42.0", path = "crates/polars", default-features = false } +polars-compute = { version = "0.42.0", path = "crates/polars-compute", default-features = false } +polars-core = { version = "0.42.0", path = "crates/polars-core", default-features = false } +polars-error = { version = "0.42.0", path = "crates/polars-error", default-features = false } +polars-expr = { version = "0.42.0", path = "crates/polars-expr", default-features = false } +polars-ffi = { version = "0.42.0", path = "crates/polars-ffi", default-features = false } +polars-io = { version = "0.42.0", path = "crates/polars-io", default-features = false } +polars-json = { version = "0.42.0", path = "crates/polars-json", default-features = false } +polars-lazy = { version = "0.42.0", path = "crates/polars-lazy", default-features = false } +polars-mem-engine = { version = "0.42.0", path = "crates/polars-mem-engine", default-features = false } +polars-ops = { version = "0.42.0", path = "crates/polars-ops", default-features = false } +polars-parquet = { version = "0.42.0", path = "crates/polars-parquet", default-features = false } +polars-pipe = { version = "0.42.0", path = "crates/polars-pipe", default-features = false } +polars-plan = { version = "0.42.0", path = "crates/polars-plan", default-features = false } +polars-python = { version = "0.42.0", path = "crates/polars-python", default-features = false } +polars-row = { version = "0.42.0", path = "crates/polars-row", default-features = false } +polars-schema = { version = "0.42.0", path = "crates/polars-schema", default-features = false } +polars-sql = { version = "0.42.0", path = "crates/polars-sql", default-features = false } +polars-stream = { version = "0.42.0", path = "crates/polars-stream", default-features = false } +polars-time = { version = "0.42.0", path = "crates/polars-time", default-features = false } +polars-utils = { version = "0.42.0", path = "crates/polars-utils", default-features = false } [workspace.dependencies.arrow-format] package = "polars-arrow-format" @@ -122,7 +125,7 @@ version = "0.1.0" [workspace.dependencies.arrow] package = "polars-arrow" -version = "0.41.3" +version = "0.42.0" path = "crates/polars-arrow" default-features = false features = [ diff --git a/Makefile b/Makefile index 524e04eddc2c..ae13f7a525bc 100644 --- a/Makefile +++ b/Makefile @@ -21,7 +21,7 @@ FILTER_PIP_WARNINGS=| grep -v "don't match your environment"; test $${PIPESTATUS requirements: .venv ## Install/refresh Python project requirements @unset CONDA_PREFIX \ && $(VENV_BIN)/python -m pip install --upgrade uv \ - && $(VENV_BIN)/uv pip install --upgrade --compile-bytecode \ + && $(VENV_BIN)/uv pip install --upgrade --compile-bytecode --no-build \ -r py-polars/requirements-dev.txt \ -r py-polars/requirements-lint.txt \ -r py-polars/docs/requirements-docs.txt \ diff --git a/README.md b/README.md index 6ae8317ab915..a106db5cc3ca 100644 --- a/README.md +++ b/README.md @@ -157,7 +157,7 @@ Refer to the [Polars CLI repository](https://github.com/pola-rs/polars-cli) for ### Blazingly 
fast -Polars is very fast. In fact, it is one of the best performing solutions available. See the [TPC-H benchmarks](https://www.pola.rs/benchmarks.html) results. +Polars is very fast. In fact, it is one of the best performing solutions available. See the [PDS-H benchmarks](https://www.pola.rs/benchmarks.html) results. ### Lightweight @@ -217,7 +217,7 @@ improvements point to the `main` branch of this repo. polars = { git = "https://github.com/pola-rs/polars", rev = "" } ``` -Requires Rust version `>=1.71`. +Requires Rust version `>=1.80`. ## Contributing diff --git a/crates/polars-arrow/Cargo.toml b/crates/polars-arrow/Cargo.toml index d88d966cf278..4339a8ec48c0 100644 --- a/crates/polars-arrow/Cargo.toml +++ b/crates/polars-arrow/Cargo.toml @@ -24,6 +24,7 @@ hashbrown = { workspace = true } num-traits = { workspace = true } parking_lot = { workspace = true } polars-error = { workspace = true } +polars-schema = { workspace = true } polars-utils = { workspace = true } serde = { workspace = true, optional = true } simdutf8 = { workspace = true } @@ -153,7 +154,7 @@ compute = [ "compute_take", "compute_temporal", ] -serde = ["dep:serde"] +serde = ["dep:serde", "polars-schema/serde"] simd = [] # polars-arrow diff --git a/crates/polars-arrow/src/array/binary/data.rs b/crates/polars-arrow/src/array/binary/data.rs index a45ebcca0621..2c08d94eb1b0 100644 --- a/crates/polars-arrow/src/array/binary/data.rs +++ b/crates/polars-arrow/src/array/binary/data.rs @@ -6,8 +6,8 @@ use crate::offset::{Offset, OffsetsBuffer}; impl Arrow2Arrow for BinaryArray { fn to_data(&self) -> ArrayData { - let data_type = self.data_type.clone().into(); - let builder = ArrayDataBuilder::new(data_type) + let dtype = self.dtype.clone().into(); + let builder = ArrayDataBuilder::new(dtype) .len(self.offsets().len_proxy()) .buffers(vec![ self.offsets.clone().into_inner().into(), @@ -20,11 +20,11 @@ impl Arrow2Arrow for BinaryArray { } fn from_data(data: &ArrayData) -> Self { - let data_type = data.data_type().clone().into(); + let dtype = data.data_type().clone().into(); if data.is_empty() { // Handle empty offsets - return Self::new_empty(data_type); + return Self::new_empty(dtype); } let buffers = data.buffers(); @@ -34,7 +34,7 @@ impl Arrow2Arrow for BinaryArray { offsets.slice(data.offset(), data.len() + 1); Self { - data_type, + dtype, offsets, values: buffers[1].clone().into(), validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), diff --git a/crates/polars-arrow/src/array/binary/ffi.rs b/crates/polars-arrow/src/array/binary/ffi.rs index c135c8d3d8dd..107cf0fcb421 100644 --- a/crates/polars-arrow/src/array/binary/ffi.rs +++ b/crates/polars-arrow/src/array/binary/ffi.rs @@ -40,7 +40,7 @@ unsafe impl ToFfi for BinaryArray { }); Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), validity, offsets: self.offsets.clone(), values: self.values.clone(), @@ -50,7 +50,7 @@ unsafe impl ToFfi for BinaryArray { impl FromFfi for BinaryArray { unsafe fn try_from_ffi(array: A) -> PolarsResult { - let data_type = array.data_type().clone(); + let dtype = array.dtype().clone(); let validity = unsafe { array.validity() }?; let offsets = unsafe { array.buffer::(1) }?; @@ -59,6 +59,6 @@ impl FromFfi for BinaryArray { // assumption that data from FFI is well constructed let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets) }; - Self::try_new(data_type, offsets, values, validity) + Self::try_new(dtype, offsets, values, validity) } } diff --git a/crates/polars-arrow/src/array/binary/mod.rs 
b/crates/polars-arrow/src/array/binary/mod.rs index 87ce30f1212a..9cd06adaaabf 100644 --- a/crates/polars-arrow/src/array/binary/mod.rs +++ b/crates/polars-arrow/src/array/binary/mod.rs @@ -56,7 +56,7 @@ mod data; /// * `len` is equal to `validity.len()`, when defined. #[derive(Clone)] pub struct BinaryArray { - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: OffsetsBuffer, values: Buffer, validity: Option, @@ -69,11 +69,11 @@ impl BinaryArray { /// This function returns an error iff: /// * The last offset is not equal to the values' length. /// * the validity's length is not equal to `offsets.len()`. - /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Binary` or `LargeBinary`. + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to either `Binary` or `LargeBinary`. /// # Implementation /// This function is `O(1)` pub fn try_new( - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: OffsetsBuffer, values: Buffer, validity: Option, @@ -87,12 +87,12 @@ impl BinaryArray { polars_bail!(ComputeError: "validity mask length must match the number of values") } - if data_type.to_physical_type() != Self::default_data_type().to_physical_type() { + if dtype.to_physical_type() != Self::default_dtype().to_physical_type() { polars_bail!(ComputeError: "BinaryArray can only be initialized with DataType::Binary or DataType::LargeBinary") } Ok(Self { - data_type, + dtype, offsets, values, validity, @@ -105,13 +105,13 @@ impl BinaryArray { /// /// The invariants must be valid (see try_new). pub unsafe fn new_unchecked( - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: OffsetsBuffer, values: Buffer, validity: Option, ) -> Self { Self { - data_type, + dtype, offsets, values, validity, @@ -188,8 +188,8 @@ impl BinaryArray { /// Returns the [`ArrowDataType`] of this array. #[inline] - pub fn data_type(&self) -> &ArrowDataType { - &self.data_type + pub fn dtype(&self) -> &ArrowDataType { + &self.dtype } /// Returns the values of this [`BinaryArray`]. 
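A minimal sketch of how the renamed `dtype` argument is validated by `BinaryArray::try_new` (illustrative only, not part of the patch; it assumes the public `polars_arrow::{array, buffer, datatypes, offset}` paths):

    use polars_arrow::array::BinaryArray;
    use polars_arrow::buffer::Buffer;
    use polars_arrow::datatypes::ArrowDataType;
    use polars_arrow::offset::OffsetsBuffer;

    // An i64-offset array pairs with LargeBinary, so an empty one builds fine.
    let ok = BinaryArray::<i64>::try_new(ArrowDataType::LargeBinary, OffsetsBuffer::new(), Buffer::new(), None);
    assert!(ok.is_ok());

    // An i32-offset array expects Binary; passing LargeBinary trips the dtype check.
    let err = BinaryArray::<i32>::try_new(ArrowDataType::LargeBinary, OffsetsBuffer::new(), Buffer::new(), None);
    assert!(err.is_err());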
@@ -246,12 +246,12 @@ impl BinaryArray { #[must_use] pub fn into_inner(self) -> (ArrowDataType, OffsetsBuffer, Buffer, Option) { let Self { - data_type, + dtype, offsets, values, validity, } = self; - (data_type, offsets, values, validity) + (dtype, offsets, values, validity) } /// Try to convert this `BinaryArray` to a `MutableBinaryArray` @@ -262,33 +262,33 @@ impl BinaryArray { match bitmap.into_mut() { // SAFETY: invariants are preserved Left(bitmap) => Left(BinaryArray::new( - self.data_type, + self.dtype, self.offsets, self.values, Some(bitmap), )), Right(mutable_bitmap) => match (self.values.into_mut(), self.offsets.into_mut()) { (Left(values), Left(offsets)) => Left(BinaryArray::new( - self.data_type, + self.dtype, offsets, values, Some(mutable_bitmap.into()), )), (Left(values), Right(offsets)) => Left(BinaryArray::new( - self.data_type, + self.dtype, offsets.into(), values, Some(mutable_bitmap.into()), )), (Right(values), Left(offsets)) => Left(BinaryArray::new( - self.data_type, + self.dtype, offsets, values.into(), Some(mutable_bitmap.into()), )), (Right(values), Right(offsets)) => Right( MutableBinaryArray::try_new( - self.data_type, + self.dtype, offsets, values, Some(mutable_bitmap), @@ -300,38 +300,32 @@ impl BinaryArray { } else { match (self.values.into_mut(), self.offsets.into_mut()) { (Left(values), Left(offsets)) => { - Left(BinaryArray::new(self.data_type, offsets, values, None)) + Left(BinaryArray::new(self.dtype, offsets, values, None)) + }, + (Left(values), Right(offsets)) => { + Left(BinaryArray::new(self.dtype, offsets.into(), values, None)) + }, + (Right(values), Left(offsets)) => { + Left(BinaryArray::new(self.dtype, offsets, values.into(), None)) + }, + (Right(values), Right(offsets)) => { + Right(MutableBinaryArray::try_new(self.dtype, offsets, values, None).unwrap()) }, - (Left(values), Right(offsets)) => Left(BinaryArray::new( - self.data_type, - offsets.into(), - values, - None, - )), - (Right(values), Left(offsets)) => Left(BinaryArray::new( - self.data_type, - offsets, - values.into(), - None, - )), - (Right(values), Right(offsets)) => Right( - MutableBinaryArray::try_new(self.data_type, offsets, values, None).unwrap(), - ), } } } /// Creates an empty [`BinaryArray`], i.e. whose `.len` is zero. - pub fn new_empty(data_type: ArrowDataType) -> Self { - Self::new(data_type, OffsetsBuffer::new(), Buffer::new(), None) + pub fn new_empty(dtype: ArrowDataType) -> Self { + Self::new(dtype, OffsetsBuffer::new(), Buffer::new(), None) } /// Creates an null [`BinaryArray`], i.e. whose `.null_count() == .len()`. #[inline] - pub fn new_null(data_type: ArrowDataType, length: usize) -> Self { + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { unsafe { Self::new_unchecked( - data_type, + dtype, Offsets::new_zeroed(length).into(), Buffer::new(), Some(Bitmap::new_zeroed(length)), @@ -340,7 +334,7 @@ impl BinaryArray { } /// Returns the default [`ArrowDataType`], `DataType::Binary` or `DataType::LargeBinary` - pub fn default_data_type() -> ArrowDataType { + pub fn default_dtype() -> ArrowDataType { if O::IS_LARGE { ArrowDataType::LargeBinary } else { @@ -350,12 +344,12 @@ impl BinaryArray { /// Alias for unwrapping [`Self::try_new`] pub fn new( - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: OffsetsBuffer, values: Buffer, validity: Option, ) -> Self { - Self::try_new(data_type, offsets, values, validity).unwrap() + Self::try_new(dtype, offsets, values, validity).unwrap() } /// Returns a [`BinaryArray`] from an iterator of trusted length. 
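The renamed accessors compose the same way as before; a short sketch under the same path assumptions, with the `Array` trait imported for `null_count`:

    use polars_arrow::array::{Array, BinaryArray};
    use polars_arrow::datatypes::ArrowDataType;

    let arr = BinaryArray::<i64>::new_null(BinaryArray::<i64>::default_dtype(), 3);
    assert_eq!(arr.dtype(), &ArrowDataType::LargeBinary);
    assert_eq!(arr.null_count(), 3);

    // `into_inner` still returns the same four parts, now led by `dtype` rather than `data_type`.
    let (dtype, _offsets, values, validity) = arr.into_inner();
    assert_eq!(dtype, ArrowDataType::LargeBinary);
    assert_eq!(values.len(), 0);
    assert!(validity.is_some());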
@@ -463,13 +457,13 @@ impl Splitable for BinaryArray { ( Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), offsets: lhs_offsets, values: self.values.clone(), validity: lhs_validity, }, Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), offsets: rhs_offsets, values: self.values.clone(), validity: rhs_validity, diff --git a/crates/polars-arrow/src/array/binary/mutable.rs b/crates/polars-arrow/src/array/binary/mutable.rs index 53a8ed32bb6f..4a8dbaafe4bf 100644 --- a/crates/polars-arrow/src/array/binary/mutable.rs +++ b/crates/polars-arrow/src/array/binary/mutable.rs @@ -52,16 +52,16 @@ impl MutableBinaryArray { /// This function returns an error iff: /// * The last offset is not equal to the values' length. /// * the validity's length is not equal to `offsets.len()`. - /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Binary` or `LargeBinary`. + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to either `Binary` or `LargeBinary`. /// # Implementation /// This function is `O(1)` pub fn try_new( - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: Offsets, values: Vec, validity: Option, ) -> PolarsResult { - let values = MutableBinaryValuesArray::try_new(data_type, offsets, values)?; + let values = MutableBinaryValuesArray::try_new(dtype, offsets, values)?; if validity .as_ref() @@ -79,8 +79,8 @@ impl MutableBinaryArray { Self::from_trusted_len_iter(slice.as_ref().iter().map(|x| x.as_ref())) } - fn default_data_type() -> ArrowDataType { - BinaryArray::::default_data_type() + fn default_dtype() -> ArrowDataType { + BinaryArray::::default_dtype() } /// Initializes a new [`MutableBinaryArray`] with a pre-allocated capacity of slots. @@ -201,8 +201,8 @@ impl MutableArray for MutableBinaryArray { array.arced() } - fn data_type(&self) -> &ArrowDataType { - self.values.data_type() + fn dtype(&self) -> &ArrowDataType { + self.values.dtype() } fn as_any(&self) -> &dyn std::any::Any { @@ -247,7 +247,7 @@ impl MutableBinaryArray { { let (validity, offsets, values) = trusted_len_unzip(iterator); - Self::try_new(Self::default_data_type(), offsets, values, validity).unwrap() + Self::try_new(Self::default_dtype(), offsets, values, validity).unwrap() } /// Creates a [`MutableBinaryArray`] from an iterator of trusted length. @@ -271,7 +271,7 @@ impl MutableBinaryArray { iterator: I, ) -> Self { let (offsets, values) = trusted_len_values_iter(iterator); - Self::try_new(Self::default_data_type(), offsets, values, None).unwrap() + Self::try_new(Self::default_dtype(), offsets, values, None).unwrap() } /// Creates a new [`BinaryArray`] from a [`TrustedLen`] of `&[u8]`. @@ -305,7 +305,7 @@ impl MutableBinaryArray { validity = None; } - Ok(Self::try_new(Self::default_data_type(), offsets, values, validity).unwrap()) + Ok(Self::try_new(Self::default_dtype(), offsets, values, validity).unwrap()) } /// Creates a [`MutableBinaryArray`] from an falible iterator of trusted length. @@ -403,7 +403,7 @@ impl MutableBinaryArray { /// Creates a new [`MutableBinaryArray`] from a [`Iterator`] of `&[u8]`. 
pub fn from_iter_values, I: Iterator>(iterator: I) -> Self { let (offsets, values) = values_iter(iterator); - Self::try_new(Self::default_data_type(), offsets, values, None).unwrap() + Self::try_new(Self::default_dtype(), offsets, values, None).unwrap() } /// Extend with a fallible iterator @@ -442,9 +442,8 @@ impl> TryPush> for MutableBinaryArray { Some(value) => { self.values.try_push(value.as_ref())?; - match &mut self.validity { - Some(validity) => validity.push(true), - None => {}, + if let Some(validity) = &mut self.validity { + validity.push(true) } }, None => { diff --git a/crates/polars-arrow/src/array/binary/mutable_values.rs b/crates/polars-arrow/src/array/binary/mutable_values.rs index 613cbb0aba9e..7da4c424d5c6 100644 --- a/crates/polars-arrow/src/array/binary/mutable_values.rs +++ b/crates/polars-arrow/src/array/binary/mutable_values.rs @@ -17,25 +17,20 @@ use crate::trusted_len::TrustedLen; /// from [`MutableBinaryArray`] in that it builds non-null [`BinaryArray`]. #[derive(Debug, Clone)] pub struct MutableBinaryValuesArray { - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: Offsets, values: Vec, } impl From> for BinaryArray { fn from(other: MutableBinaryValuesArray) -> Self { - BinaryArray::::new( - other.data_type, - other.offsets.into(), - other.values.into(), - None, - ) + BinaryArray::::new(other.dtype, other.offsets.into(), other.values.into(), None) } } impl From> for MutableBinaryArray { fn from(other: MutableBinaryValuesArray) -> Self { - MutableBinaryArray::::try_new(other.data_type, other.offsets, other.values, None) + MutableBinaryArray::::try_new(other.dtype, other.offsets, other.values, None) .expect("MutableBinaryValuesArray is consistent with MutableBinaryArray") } } @@ -50,7 +45,7 @@ impl MutableBinaryValuesArray { /// Returns an empty [`MutableBinaryValuesArray`]. pub fn new() -> Self { Self { - data_type: Self::default_data_type(), + dtype: Self::default_dtype(), offsets: Offsets::new(), values: Vec::::new(), } @@ -61,22 +56,22 @@ impl MutableBinaryValuesArray { /// # Errors /// This function returns an error iff: /// * The last offset is not equal to the values' length. - /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Binary` or `LargeBinary`. + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to either `Binary` or `LargeBinary`. /// # Implementation /// This function is `O(1)` pub fn try_new( - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: Offsets, values: Vec, ) -> PolarsResult { try_check_offsets_bounds(&offsets, values.len())?; - if data_type.to_physical_type() != Self::default_data_type().to_physical_type() { + if dtype.to_physical_type() != Self::default_dtype().to_physical_type() { polars_bail!(ComputeError: "MutableBinaryValuesArray can only be initialized with DataType::Binary or DataType::LargeBinary",) } Ok(Self { - data_type, + dtype, offsets, values, }) @@ -84,8 +79,8 @@ impl MutableBinaryValuesArray { /// Returns the default [`ArrowDataType`] of this container: [`ArrowDataType::Utf8`] or [`ArrowDataType::LargeUtf8`] /// depending on the generic [`Offset`]. - pub fn default_data_type() -> ArrowDataType { - BinaryArray::::default_data_type() + pub fn default_dtype() -> ArrowDataType { + BinaryArray::::default_dtype() } /// Initializes a new [`MutableBinaryValuesArray`] with a pre-allocated capacity of items. @@ -96,7 +91,7 @@ impl MutableBinaryValuesArray { /// Initializes a new [`MutableBinaryValuesArray`] with a pre-allocated capacity of items and values. 
pub fn with_capacities(capacity: usize, values: usize) -> Self { Self { - data_type: Self::default_data_type(), + dtype: Self::default_dtype(), offsets: Offsets::::with_capacity(capacity), values: Vec::::with_capacity(values), } @@ -187,7 +182,7 @@ impl MutableBinaryValuesArray { /// Extract the low-end APIs from the [`MutableBinaryValuesArray`]. pub fn into_inner(self) -> (ArrowDataType, Offsets, Vec) { - (self.data_type, self.offsets, self.values) + (self.dtype, self.offsets, self.values) } } @@ -201,17 +196,17 @@ impl MutableArray for MutableBinaryValuesArray { } fn as_box(&mut self) -> Box { - let (data_type, offsets, values) = std::mem::take(self).into_inner(); - BinaryArray::new(data_type, offsets.into(), values.into(), None).boxed() + let (dtype, offsets, values) = std::mem::take(self).into_inner(); + BinaryArray::new(dtype, offsets.into(), values.into(), None).boxed() } fn as_arc(&mut self) -> Arc { - let (data_type, offsets, values) = std::mem::take(self).into_inner(); - BinaryArray::new(data_type, offsets.into(), values.into(), None).arced() + let (dtype, offsets, values) = std::mem::take(self).into_inner(); + BinaryArray::new(dtype, offsets.into(), values.into(), None).arced() } - fn data_type(&self) -> &ArrowDataType { - &self.data_type + fn dtype(&self) -> &ArrowDataType { + &self.dtype } fn as_any(&self) -> &dyn std::any::Any { @@ -239,7 +234,7 @@ impl MutableArray for MutableBinaryValuesArray { impl> FromIterator
for MutableBinaryValuesArray { fn from_iter>(iter: I) -> Self { let (offsets, values) = values_iter(iter.into_iter()); - Self::try_new(Self::default_data_type(), offsets, values).unwrap() + Self::try_new(Self::default_dtype(), offsets, values).unwrap() } } @@ -301,7 +296,7 @@ impl MutableBinaryValuesArray { I: Iterator, { let (offsets, values) = trusted_len_values_iter(iterator); - Self::try_new(Self::default_data_type(), offsets, values).unwrap() + Self::try_new(Self::default_dtype(), offsets, values).unwrap() } /// Returns a new [`MutableBinaryValuesArray`] from an iterator. diff --git a/crates/polars-arrow/src/array/binview/ffi.rs b/crates/polars-arrow/src/array/binview/ffi.rs index 8ea36c9d1de7..3fc11278dcfb 100644 --- a/crates/polars-arrow/src/array/binview/ffi.rs +++ b/crates/polars-arrow/src/array/binview/ffi.rs @@ -43,7 +43,7 @@ unsafe impl ToFfi for BinaryViewArrayGeneric { }); Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), validity, views: self.views.clone(), buffers: self.buffers.clone(), @@ -56,7 +56,7 @@ unsafe impl ToFfi for BinaryViewArrayGeneric { impl FromFfi for BinaryViewArrayGeneric { unsafe fn try_from_ffi(array: A) -> PolarsResult { - let data_type = array.data_type().clone(); + let dtype = array.dtype().clone(); let validity = unsafe { array.validity() }?; let views = unsafe { array.buffer::(1) }?; @@ -66,7 +66,7 @@ impl FromFfi for BinaryViewArray let mut remaining_buffers = n_buffers - 2; if remaining_buffers <= 1 { return Ok(Self::new_unchecked_unknown_md( - data_type, + dtype, views, Arc::from([]), validity, @@ -90,7 +90,7 @@ impl FromFfi for BinaryViewArray } Ok(Self::new_unchecked_unknown_md( - data_type, + dtype, views, Arc::from(variadic_buffers), validity, diff --git a/crates/polars-arrow/src/array/binview/mod.rs b/crates/polars-arrow/src/array/binview/mod.rs index 85595ba4a7a8..d3fcc3c263d3 100644 --- a/crates/polars-arrow/src/array/binview/mod.rs +++ b/crates/polars-arrow/src/array/binview/mod.rs @@ -30,12 +30,12 @@ use polars_utils::aliases::{InitHashMaps, PlHashMap}; use polars_utils::slice::GetSaferUnchecked; use private::Sealed; -use crate::array::binview::view::{validate_binary_view, validate_utf8_only, validate_utf8_view}; +use crate::array::binview::view::{validate_binary_view, validate_utf8_only}; use crate::array::iterator::NonNullValuesIter; use crate::bitmap::utils::{BitmapIter, ZipValidity}; pub type BinaryViewArray = BinaryViewArrayGeneric<[u8]>; pub type Utf8ViewArray = BinaryViewArrayGeneric; -pub use view::View; +pub use view::{validate_utf8_view, View}; use super::Splitable; @@ -110,7 +110,7 @@ impl ViewType for [u8] { } pub struct BinaryViewArrayGeneric { - data_type: ArrowDataType, + dtype: ArrowDataType, views: Buffer, buffers: Arc<[Buffer]>, validity: Option, @@ -130,7 +130,7 @@ impl PartialEq for BinaryViewArrayGeneric { impl Clone for BinaryViewArrayGeneric { fn clone(&self) -> Self { Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), views: self.views.clone(), buffers: self.buffers.clone(), validity: self.validity.clone(), @@ -152,15 +152,43 @@ impl BinaryViewArrayGeneric { /// - the data is valid utf8 (if required) /// - The offsets match the buffers. pub unsafe fn new_unchecked( - data_type: ArrowDataType, + dtype: ArrowDataType, views: Buffer, buffers: Arc<[Buffer]>, validity: Option, total_bytes_len: usize, total_buffer_len: usize, ) -> Self { + // Verify the invariants + #[cfg(debug_assertions)] + { + // @TODO: Enable this. This is currently bugged with concatenate. 
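+ // Only views longer than `View::MAX_INLINE_SIZE` (12 bytes) reference a backing buffer
+ // through `buffer_idx` and `offset`, so those are the ones checked below; inline views
+ // carry their bytes in the view itself and need no bounds check.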
+ // let mut actual_total_buffer_len = 0; + // let mut actual_total_bytes_len = 0; + // + // for buffer in buffers.iter() { + // actual_total_buffer_len += buffer.len(); + // } + + for view in views.iter() { + // actual_total_bytes_len += view.length as usize; + if view.length > View::MAX_INLINE_SIZE { + assert!((view.buffer_idx as usize) < (buffers.len())); + assert!( + view.offset as usize + view.length as usize + <= buffers[view.buffer_idx as usize].len() + ); + } + } + + // assert_eq!(actual_total_buffer_len, total_buffer_len); + // if (total_bytes_len as u64) != UNKNOWN_LEN { + // assert_eq!(actual_total_bytes_len, total_bytes_len); + // } + } + Self { - data_type, + dtype, views, buffers, validity, @@ -175,7 +203,7 @@ impl BinaryViewArrayGeneric { /// # Safety /// The caller must ensure the invariants pub unsafe fn new_unchecked_unknown_md( - data_type: ArrowDataType, + dtype: ArrowDataType, views: Buffer, buffers: Arc<[Buffer]>, validity: Option, @@ -185,7 +213,7 @@ impl BinaryViewArrayGeneric { let total_buffer_len = total_buffer_len.unwrap_or_else(|| buffers.iter().map(|b| b.len()).sum()); Self::new_unchecked( - data_type, + dtype, views, buffers, validity, @@ -246,7 +274,7 @@ impl BinaryViewArrayGeneric { *v = update_view(*v, str_slice); } Self::new_unchecked( - self.data_type.clone(), + self.dtype.clone(), views.into(), buffers, validity, @@ -256,7 +284,7 @@ impl BinaryViewArrayGeneric { } pub fn try_new( - data_type: ArrowDataType, + dtype: ArrowDataType, views: Buffer, buffers: Arc<[Buffer]>, validity: Option, @@ -273,31 +301,22 @@ impl BinaryViewArrayGeneric { unsafe { Ok(Self::new_unchecked_unknown_md( - data_type, views, buffers, validity, None, + dtype, views, buffers, validity, None, )) } } /// Creates an empty [`BinaryViewArrayGeneric`], i.e. whose `.len` is zero. #[inline] - pub fn new_empty(data_type: ArrowDataType) -> Self { - unsafe { Self::new_unchecked(data_type, Buffer::new(), Arc::from([]), None, 0, 0) } + pub fn new_empty(dtype: ArrowDataType) -> Self { + unsafe { Self::new_unchecked(dtype, Buffer::new(), Arc::from([]), None, 0, 0) } } /// Returns a new null [`BinaryViewArrayGeneric`] of `length`. 
#[inline] - pub fn new_null(data_type: ArrowDataType, length: usize) -> Self { + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { let validity = Some(Bitmap::new_zeroed(length)); - unsafe { - Self::new_unchecked( - data_type, - Buffer::zeroed(length), - Arc::from([]), - validity, - 0, - 0, - ) - } + unsafe { Self::new_unchecked(dtype, Buffer::zeroed(length), Arc::from([]), validity, 0, 0) } } /// Returns the element at index `i` @@ -407,7 +426,7 @@ impl BinaryViewArrayGeneric { let buffers = self.buffers.as_ref(); for view in self.views.as_ref() { - unsafe { mutable.push_view_copied_unchecked(*view, buffers) } + unsafe { mutable.push_view_unchecked(*view, buffers) } } mutable.freeze().with_validity(self.validity) } @@ -525,7 +544,7 @@ impl Array for BinaryViewArrayGeneric { BinaryViewArrayGeneric::len(self) } - fn data_type(&self) -> &ArrowDataType { + fn dtype(&self) -> &ArrowDataType { T::dtype() } @@ -592,7 +611,7 @@ impl Splitable for BinaryViewArrayGeneric { unsafe { ( Self::new_unchecked( - self.data_type.clone(), + self.dtype.clone(), lhs_views, self.buffers.clone(), lhs_validity, @@ -600,7 +619,7 @@ impl Splitable for BinaryViewArrayGeneric { self.total_buffer_len(), ), Self::new_unchecked( - self.data_type.clone(), + self.dtype.clone(), rhs_views, self.buffers.clone(), rhs_validity, diff --git a/crates/polars-arrow/src/array/binview/mutable.rs b/crates/polars-arrow/src/array/binview/mutable.rs index a4a9a2da6a52..0d7dcac94b5a 100644 --- a/crates/polars-arrow/src/array/binview/mutable.rs +++ b/crates/polars-arrow/src/array/binview/mutable.rs @@ -21,6 +21,11 @@ use crate::trusted_len::TrustedLen; const DEFAULT_BLOCK_SIZE: usize = 8 * 1024; const MAX_EXP_BLOCK_SIZE: usize = 16 * 1024 * 1024; +// Invariants: +// +// - Each view must point to a valid slice of a buffer +// - `total_buffer_len` must be equal to `completed_buffers.iter().map(Vec::len).sum()` +// - `total_bytes_len` must be equal to `views.iter().map(View::len).sum()` pub struct MutableBinaryViewArray { pub(crate) views: Vec, pub(crate) completed_buffers: Vec>, @@ -97,11 +102,41 @@ impl MutableBinaryViewArray { } } + /// Get a mutable reference to the [`Vec`] of [`View`]s in this [`MutableBinaryViewArray`]. + /// + /// # Safety + /// + /// This is safe as long as any mutation of the [`Vec`] does not break any invariants of the + /// [`MutableBinaryViewArray`] before it is read again. #[inline] - pub fn views_mut(&mut self) -> &mut Vec { + pub unsafe fn views_mut(&mut self) -> &mut Vec { &mut self.views } + /// Set the `total_bytes_len` of the [`MutableBinaryViewArray`] + /// + /// # Safety + /// + /// This should not break invariants of the [`MutableBinaryViewArray`] + #[inline] + pub unsafe fn set_total_bytes_len(&mut self, value: usize) { + #[cfg(debug_assertions)] + { + let actual_length: usize = self.views().iter().map(|v| v.length as usize).sum(); + assert_eq!(value, actual_length); + } + + self.total_bytes_len = value; + } + + pub fn total_bytes_len(&self) -> usize { + self.total_bytes_len + } + + pub fn total_buffer_len(&self) -> usize { + self.total_buffer_len + } + #[inline] pub fn views(&self) -> &[View] { &self.views @@ -144,7 +179,7 @@ impl MutableBinaryViewArray { /// - caller must allocate enough capacity /// - caller must ensure the view and buffers match. /// - The array must not have validity. 
- pub(crate) unsafe fn push_view_copied_unchecked(&mut self, v: View, buffers: &[Buffer]) { + pub(crate) unsafe fn push_view_unchecked(&mut self, v: View, buffers: &[Buffer]) { let len = v.length; self.total_bytes_len += len as usize; if len <= 12 { @@ -165,7 +200,7 @@ impl MutableBinaryViewArray { /// - caller must ensure the view and buffers match. /// - The array must not have validity. /// - caller must not mix use this function with other push functions. - pub unsafe fn push_view_unchecked(&mut self, mut v: View, buffers: &[Buffer]) { + pub unsafe fn push_view_unchecked_dedupe(&mut self, mut v: View, buffers: &[Buffer]) { let len = v.length; self.total_bytes_len += len as usize; if len <= 12 { @@ -268,12 +303,10 @@ impl MutableBinaryViewArray { #[inline] pub fn push_buffer(&mut self, buffer: Buffer) -> u32 { - if !self.in_progress_buffer.is_empty() { - self.completed_buffers - .push(Buffer::from(std::mem::take(&mut self.in_progress_buffer))); - } + self.finish_in_progress(); let buffer_idx = self.completed_buffers.len(); + self.total_buffer_len += buffer.len(); self.completed_buffers.push(buffer); buffer_idx as u32 } @@ -438,14 +471,17 @@ impl MutableBinaryViewArray { /// # Safety /// Same as `push_view_unchecked()`. #[inline] - pub unsafe fn extend_non_null_views_trusted_len_unchecked( + pub unsafe fn extend_non_null_views_unchecked_dedupe( &mut self, iterator: I, buffers: &[Buffer], ) where - I: TrustedLen, + I: Iterator, { - self.extend_non_null_views_unchecked(iterator, buffers); + self.reserve(iterator.size_hint().0); + for v in iterator { + self.push_view_unchecked_dedupe(v, buffers); + } } #[inline] @@ -491,7 +527,7 @@ impl MutableBinaryViewArray { #[inline] pub fn freeze_with_dtype(self, dtype: ArrowDataType) -> BinaryViewArrayGeneric { let mut arr: BinaryViewArrayGeneric = self.into(); - arr.data_type = dtype; + arr.dtype = dtype; arr } @@ -570,6 +606,135 @@ impl MutableBinaryViewArray<[u8]> { } Ok(()) } + + /// Extend from a `buffer` and `length` of items given some statistics about the lengths. + /// + /// This will attempt to dispatch to several optimized implementations. + /// + /// # Safety + /// + /// This is safe if the statistics are correct. + pub unsafe fn extend_from_lengths_with_stats( + &mut self, + buffer: &[u8], + lengths_iterator: impl Clone + ExactSizeIterator, + min_length: usize, + max_length: usize, + sum_length: usize, + ) { + let num_items = lengths_iterator.len(); + + if num_items == 0 { + return; + } + + #[cfg(debug_assertions)] + { + let (min, max, sum) = lengths_iterator.clone().map(|v| (v, v, v)).fold( + (usize::MAX, usize::MIN, 0usize), + |(cmin, cmax, csum), (emin, emax, esum)| { + (cmin.min(emin), cmax.max(emax), csum + esum) + }, + ); + + assert_eq!(min, min_length); + assert_eq!(max, max_length); + assert_eq!(sum, sum_length); + } + + assert!(sum_length <= buffer.len()); + + let mut buffer_offset = 0; + if min_length > View::MAX_INLINE_SIZE as usize + && (num_items == 1 || sum_length + self.in_progress_buffer.len() <= u32::MAX as usize) + { + let buffer_idx = self.completed_buffers().len() as u32; + let in_progress_buffer_offset = self.in_progress_buffer.len(); + + self.total_bytes_len += sum_length; + self.total_buffer_len += sum_length; + + self.in_progress_buffer + .extend_from_slice(&buffer[..sum_length]); + self.views.extend(lengths_iterator.map(|length| { + // SAFETY: We asserted before that the sum of all lengths is smaller or equal to + // the buffer length. 
+ let view_buffer = + unsafe { buffer.get_unchecked(buffer_offset..buffer_offset + length) }; + + // SAFETY: We know that the minimum length > View::MAX_INLINE_SIZE. Therefore, this + // length is > View::MAX_INLINE_SIZE. + let view = unsafe { + View::new_noninline_unchecked( + view_buffer, + buffer_idx, + (buffer_offset + in_progress_buffer_offset) as u32, + ) + }; + buffer_offset += length; + view + })); + } else if max_length <= View::MAX_INLINE_SIZE as usize { + self.total_bytes_len += sum_length; + + // If the min and max are the same, we can dispatch to the optimized SIMD + // implementation. + if min_length == max_length { + let length = min_length; + if length == 0 { + self.views + .resize(self.views.len() + num_items, View::new_inline(&[])); + } else { + View::extend_with_inlinable_strided( + &mut self.views, + &buffer[..length * num_items], + length as u8, + ); + } + } else { + self.views.extend(lengths_iterator.map(|length| { + // SAFETY: We asserted before that the sum of all lengths is smaller or equal + // to the buffer length. + let view_buffer = + unsafe { buffer.get_unchecked(buffer_offset..buffer_offset + length) }; + + // SAFETY: We know that each view has a length <= View::MAX_INLINE_SIZE because + // the maximum length is <= View::MAX_INLINE_SIZE + let view = unsafe { View::new_inline_unchecked(view_buffer) }; + + buffer_offset += length; + + view + })); + } + } else { + // If all fails, just fall back to a base implementation. + self.reserve(num_items); + for length in lengths_iterator { + let value = &buffer[buffer_offset..buffer_offset + length]; + buffer_offset += length; + self.push_value(value); + } + } + } + + /// Extend from a `buffer` and `length` of items. + /// + /// This will attempt to dispatch to several optimized implementations. + #[inline] + pub fn extend_from_lengths( + &mut self, + buffer: &[u8], + lengths_iterator: impl Clone + ExactSizeIterator, + ) { + let (min, max, sum) = lengths_iterator.clone().map(|v| (v, v, v)).fold( + (usize::MAX, usize::MIN, 0usize), + |(cmin, cmax, csum), (emin, emax, esum)| (cmin.min(emin), cmax.max(emax), csum + esum), + ); + + // SAFETY: We just collected the right stats. + unsafe { self.extend_from_lengths_with_stats(buffer, lengths_iterator, min, max, sum) } + } } impl> Extend> for MutableBinaryViewArray { @@ -587,7 +752,7 @@ impl> FromIterator> for MutableBinar } impl MutableArray for MutableBinaryViewArray { - fn data_type(&self) -> &ArrowDataType { + fn dtype(&self) -> &ArrowDataType { T::dtype() } @@ -643,3 +808,54 @@ impl> TryPush> for MutableBinaryView Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + + fn roundtrip(values: &[&[u8]]) -> bool { + let buffer = values + .iter() + .flat_map(|v| v.iter().copied()) + .collect::>(); + let lengths = values.iter().map(|v| v.len()).collect::>(); + let mut bv = MutableBinaryViewArray::<[u8]>::with_capacity(values.len()); + + bv.extend_from_lengths(&buffer[..], lengths.into_iter()); + + &bv.values_iter().collect::>()[..] 
== values + } + + #[test] + fn extend_with_lengths_basic() { + assert!(roundtrip(&[])); + assert!(roundtrip(&[b"abc"])); + assert!(roundtrip(&[ + b"a_very_very_long_string_that_is_not_inlinable" + ])); + assert!(roundtrip(&[ + b"abc", + b"a_very_very_long_string_that_is_not_inlinable" + ])); + } + + #[test] + fn extend_with_inlinable_fastpath() { + assert!(roundtrip(&[b"abc", b"defg", b"hix"])); + assert!(roundtrip(&[b"abc", b"defg", b"hix", b"xyza1234abcd"])); + } + + #[test] + fn extend_with_inlinable_eq_len_fastpath() { + assert!(roundtrip(&[b"abc", b"def", b"hix"])); + assert!(roundtrip(&[b"abc", b"def", b"hix", b"xyz"])); + } + + #[test] + fn extend_with_not_inlinable_fastpath() { + assert!(roundtrip(&[ + b"a_very_long_string123", + b"a_longer_string_than_the_previous" + ])); + } +} diff --git a/crates/polars-arrow/src/array/binview/view.rs b/crates/polars-arrow/src/array/binview/view.rs index 6542e2c9761b..67334a53aa17 100644 --- a/crates/polars-arrow/src/array/binview/view.rs +++ b/crates/polars-arrow/src/array/binview/view.rs @@ -157,12 +157,12 @@ impl View { /// Extend a `Vec` with inline views slices of `src` with `width`. /// /// This tries to use SIMD to optimize the copying and can be massively faster than doing a - /// `views.extend(src.chunks_exact(stride).map(View::new_inline))`. + /// `views.extend(src.chunks_exact(width).map(View::new_inline))`. /// /// # Panics /// - /// This function panics if `src.len()` is not divisible by `width` or if `width > - /// View::MAX_INLINE_SIZE`. + /// This function panics if `src.len()` is not divisible by `width`, `width > + /// View::MAX_INLINE_SIZE` or `width == 0`. pub fn extend_with_inlinable_strided(views: &mut Vec, src: &[u8], width: u8) { macro_rules! dispatch { ($n:ident = $match:ident in [$($v:literal),+ $(,)?] 
=> $block:block, otherwise = $otherwise:expr) => { @@ -180,17 +180,16 @@ impl View { } let width = width as usize; - assert_eq!(src.len() % width, 0); + + assert!(width > 0); assert!(width <= View::MAX_INLINE_SIZE as usize); + + assert_eq!(src.len() % width, 0); + let num_values = src.len() / width; views.reserve(num_values); - if width == 0 { - views.resize(views.len() + num_values, View::new_inline(&[])); - return; - } - #[allow(unused_mut)] let mut src = src; @@ -427,7 +426,7 @@ fn validate_utf8(b: &[u8]) -> PolarsResult<()> { } } -pub(super) fn validate_utf8_view(views: &[View], buffers: &[Buffer]) -> PolarsResult<()> { +pub fn validate_utf8_view(views: &[View], buffers: &[Buffer]) -> PolarsResult<()> { validate_view(views, buffers, validate_utf8) } diff --git a/crates/polars-arrow/src/array/boolean/data.rs b/crates/polars-arrow/src/array/boolean/data.rs index f472348a0407..6c497896775c 100644 --- a/crates/polars-arrow/src/array/boolean/data.rs +++ b/crates/polars-arrow/src/array/boolean/data.rs @@ -28,7 +28,7 @@ impl Arrow2Arrow for BooleanArray { let values = Bitmap::from_null_buffer(NullBuffer::new(buffer)); Self { - data_type: ArrowDataType::Boolean, + dtype: ArrowDataType::Boolean, values, validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), } diff --git a/crates/polars-arrow/src/array/boolean/ffi.rs b/crates/polars-arrow/src/array/boolean/ffi.rs index bd8693f2dbb1..dfaf3ac90571 100644 --- a/crates/polars-arrow/src/array/boolean/ffi.rs +++ b/crates/polars-arrow/src/array/boolean/ffi.rs @@ -38,7 +38,7 @@ unsafe impl ToFfi for BooleanArray { }); Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), validity, values: self.values.clone(), } @@ -47,9 +47,9 @@ unsafe impl ToFfi for BooleanArray { impl FromFfi for BooleanArray { unsafe fn try_from_ffi(array: A) -> PolarsResult { - let data_type = array.data_type().clone(); + let dtype = array.dtype().clone(); let validity = unsafe { array.validity() }?; let values = unsafe { array.bitmap(1) }?; - Self::try_new(data_type, values, validity) + Self::try_new(dtype, values, validity) } } diff --git a/crates/polars-arrow/src/array/boolean/mod.rs b/crates/polars-arrow/src/array/boolean/mod.rs index 656a6db7e89a..5cd9870fdbf4 100644 --- a/crates/polars-arrow/src/array/boolean/mod.rs +++ b/crates/polars-arrow/src/array/boolean/mod.rs @@ -45,7 +45,7 @@ use polars_error::{polars_bail, PolarsResult}; /// ``` #[derive(Clone)] pub struct BooleanArray { - data_type: ArrowDataType, + dtype: ArrowDataType, values: Bitmap, validity: Option, } @@ -55,9 +55,9 @@ impl BooleanArray { /// # Errors /// This function errors iff: /// * The validity is not `None` and its length is different from `values`'s length - /// * The `data_type`'s [`PhysicalType`] is not equal to [`PhysicalType::Boolean`]. + /// * The `dtype`'s [`PhysicalType`] is not equal to [`PhysicalType::Boolean`]. 
pub fn try_new( - data_type: ArrowDataType, + dtype: ArrowDataType, values: Bitmap, validity: Option, ) -> PolarsResult { @@ -68,20 +68,20 @@ impl BooleanArray { polars_bail!(ComputeError: "validity mask length must match the number of values") } - if data_type.to_physical_type() != PhysicalType::Boolean { + if dtype.to_physical_type() != PhysicalType::Boolean { polars_bail!(ComputeError: "BooleanArray can only be initialized with a DataType whose physical type is Boolean") } Ok(Self { - data_type, + dtype, values, validity, }) } /// Alias to `Self::try_new().unwrap()` - pub fn new(data_type: ArrowDataType, values: Bitmap, validity: Option) -> Self { - Self::try_new(data_type, values, validity).unwrap() + pub fn new(dtype: ArrowDataType, values: Bitmap, validity: Option) -> Self { + Self::try_new(dtype, values, validity).unwrap() } /// Returns an iterator over the optional values of this [`BooleanArray`]. @@ -123,8 +123,8 @@ impl BooleanArray { /// Returns the arrays' [`ArrowDataType`]. #[inline] - pub fn data_type(&self) -> &ArrowDataType { - &self.data_type + pub fn dtype(&self) -> &ArrowDataType { + &self.dtype } /// Returns the value at index `i` @@ -238,38 +238,38 @@ impl BooleanArray { if let Some(bitmap) = self.validity { match bitmap.into_mut() { - Left(bitmap) => Left(BooleanArray::new(self.data_type, self.values, Some(bitmap))), + Left(bitmap) => Left(BooleanArray::new(self.dtype, self.values, Some(bitmap))), Right(mutable_bitmap) => match self.values.into_mut() { Left(immutable) => Left(BooleanArray::new( - self.data_type, + self.dtype, immutable, Some(mutable_bitmap.into()), )), Right(mutable) => Right( - MutableBooleanArray::try_new(self.data_type, mutable, Some(mutable_bitmap)) + MutableBooleanArray::try_new(self.dtype, mutable, Some(mutable_bitmap)) .unwrap(), ), }, } } else { match self.values.into_mut() { - Left(immutable) => Left(BooleanArray::new(self.data_type, immutable, None)), + Left(immutable) => Left(BooleanArray::new(self.dtype, immutable, None)), Right(mutable) => { - Right(MutableBooleanArray::try_new(self.data_type, mutable, None).unwrap()) + Right(MutableBooleanArray::try_new(self.dtype, mutable, None).unwrap()) }, } } } /// Returns a new empty [`BooleanArray`]. - pub fn new_empty(data_type: ArrowDataType) -> Self { - Self::new(data_type, Bitmap::new(), None) + pub fn new_empty(dtype: ArrowDataType) -> Self { + Self::new(dtype, Bitmap::new(), None) } /// Returns a new [`BooleanArray`] whose all slots are null / `None`. - pub fn new_null(data_type: ArrowDataType, length: usize) -> Self { + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { let bitmap = Bitmap::new_zeroed(length); - Self::new(data_type, bitmap.clone(), Some(bitmap)) + Self::new(dtype, bitmap.clone(), Some(bitmap)) } /// Creates a new [`BooleanArray`] from an [`TrustedLen`] of `bool`. @@ -352,11 +352,11 @@ impl BooleanArray { #[must_use] pub fn into_inner(self) -> (ArrowDataType, Bitmap, Option) { let Self { - data_type, + dtype, values, validity, } = self; - (data_type, values, validity) + (dtype, values, validity) } /// Creates a `[BooleanArray]` from its internal representation. @@ -365,12 +365,12 @@ impl BooleanArray { /// # Safety /// Callers must ensure all invariants of this struct are upheld. 
pub unsafe fn from_inner_unchecked( - data_type: ArrowDataType, + dtype: ArrowDataType, values: Bitmap, validity: Option, ) -> Self { Self { - data_type, + dtype, values, validity, } @@ -401,12 +401,12 @@ impl Splitable for BooleanArray { ( Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), values: lhs_values, validity: lhs_validity, }, Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), values: rhs_values, validity: rhs_validity, }, @@ -417,7 +417,7 @@ impl Splitable for BooleanArray { impl From for BooleanArray { fn from(values: Bitmap) -> Self { Self { - data_type: ArrowDataType::Boolean, + dtype: ArrowDataType::Boolean, values, validity: None, } diff --git a/crates/polars-arrow/src/array/boolean/mutable.rs b/crates/polars-arrow/src/array/boolean/mutable.rs index 80d689806f1d..f93707db4846 100644 --- a/crates/polars-arrow/src/array/boolean/mutable.rs +++ b/crates/polars-arrow/src/array/boolean/mutable.rs @@ -15,7 +15,7 @@ use crate::trusted_len::TrustedLen; /// This struct does not allocate a validity until one is required (i.e. push a null to it). #[derive(Debug, Clone)] pub struct MutableBooleanArray { - data_type: ArrowDataType, + dtype: ArrowDataType, values: MutableBitmap, validity: Option, } @@ -23,7 +23,7 @@ pub struct MutableBooleanArray { impl From for BooleanArray { fn from(other: MutableBooleanArray) -> Self { BooleanArray::new( - other.data_type, + other.dtype, other.values.into(), other.validity.map(|x| x.into()), ) @@ -53,9 +53,9 @@ impl MutableBooleanArray { /// # Errors /// This function errors iff: /// * The validity is not `None` and its length is different from `values`'s length - /// * The `data_type`'s [`PhysicalType`] is not equal to [`PhysicalType::Boolean`]. + /// * The `dtype`'s [`PhysicalType`] is not equal to [`PhysicalType::Boolean`]. pub fn try_new( - data_type: ArrowDataType, + dtype: ArrowDataType, values: MutableBitmap, validity: Option, ) -> PolarsResult { @@ -68,14 +68,14 @@ impl MutableBooleanArray { ) } - if data_type.to_physical_type() != PhysicalType::Boolean { + if dtype.to_physical_type() != PhysicalType::Boolean { polars_bail!(oos = "MutableBooleanArray can only be initialized with a DataType whose physical type is Boolean", ) } Ok(Self { - data_type, + dtype, values, validity, }) @@ -84,7 +84,7 @@ impl MutableBooleanArray { /// Creates an new [`MutableBooleanArray`] with a capacity of values. 
pub fn with_capacity(capacity: usize) -> Self { Self { - data_type: ArrowDataType::Boolean, + dtype: ArrowDataType::Boolean, values: MutableBitmap::with_capacity(capacity), validity: None, } @@ -101,9 +101,8 @@ impl MutableBooleanArray { #[inline] pub fn push_value(&mut self, value: bool) { self.values.push(value); - match &mut self.validity { - Some(validity) => validity.push(true), - None => {}, + if let Some(validity) = &mut self.validity { + validity.push(true) } } @@ -534,8 +533,8 @@ impl MutableArray for MutableBooleanArray { array.arced() } - fn data_type(&self) -> &ArrowDataType { - &self.data_type + fn dtype(&self) -> &ArrowDataType { + &self.dtype } fn as_any(&self) -> &dyn std::any::Any { diff --git a/crates/polars-arrow/src/array/dictionary/data.rs b/crates/polars-arrow/src/array/dictionary/data.rs index e7159e4bfff2..a5eda5a0fd73 100644 --- a/crates/polars-arrow/src/array/dictionary/data.rs +++ b/crates/polars-arrow/src/array/dictionary/data.rs @@ -10,7 +10,7 @@ impl Arrow2Arrow for DictionaryArray { let keys = self.keys.to_data(); let builder = keys .into_builder() - .data_type(self.data_type.clone().into()) + .data_type(self.dtype.clone().into()) .child_data(vec![to_data(self.values.as_ref())]); // SAFETY: Dictionary is valid @@ -23,9 +23,9 @@ impl Arrow2Arrow for DictionaryArray { d => panic!("unsupported dictionary type {d}"), }; - let data_type = ArrowDataType::from(data.data_type().clone()); + let dtype = ArrowDataType::from(data.data_type().clone()); assert_eq!( - data_type.to_physical_type(), + dtype.to_physical_type(), PhysicalType::Dictionary(K::KEY_TYPE) ); @@ -41,7 +41,7 @@ impl Arrow2Arrow for DictionaryArray { let values = from_data(&data.child_data()[0]); Self { - data_type, + dtype, keys, values, } diff --git a/crates/polars-arrow/src/array/dictionary/ffi.rs b/crates/polars-arrow/src/array/dictionary/ffi.rs index b22c27eacead..025a4bbb9b69 100644 --- a/crates/polars-arrow/src/array/dictionary/ffi.rs +++ b/crates/polars-arrow/src/array/dictionary/ffi.rs @@ -15,7 +15,7 @@ unsafe impl ToFfi for DictionaryArray { fn to_ffi_aligned(&self) -> Self { Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), keys: self.keys.to_ffi_aligned(), values: self.values.clone(), } @@ -28,7 +28,7 @@ impl FromFfi for DictionaryArray let validity = unsafe { array.validity() }?; let values = unsafe { array.buffer::(1) }?; - let data_type = array.data_type().clone(); + let dtype = array.dtype().clone(); let keys = PrimitiveArray::::try_new(K::PRIMITIVE.into(), values, validity)?; let values = array.dictionary()?.ok_or_else( @@ -37,6 +37,6 @@ impl FromFfi for DictionaryArray let values = ffi::try_from(values)?; // the assumption of this trait - DictionaryArray::::try_new_unchecked(data_type, keys, values) + DictionaryArray::::try_new_unchecked(dtype, keys, values) } } diff --git a/crates/polars-arrow/src/array/dictionary/mod.rs b/crates/polars-arrow/src/array/dictionary/mod.rs index 7420b02f7891..d53970dacd98 100644 --- a/crates/polars-arrow/src/array/dictionary/mod.rs +++ b/crates/polars-arrow/src/array/dictionary/mod.rs @@ -126,21 +126,21 @@ unsafe impl DictionaryKey for u64 { /// use `unchecked` calls to retrieve the values #[derive(Clone)] pub struct DictionaryArray { - data_type: ArrowDataType, + dtype: ArrowDataType, keys: PrimitiveArray, values: Box, } -fn check_data_type( +fn check_dtype( key_type: IntegerType, - data_type: &ArrowDataType, - values_data_type: &ArrowDataType, + dtype: &ArrowDataType, + values_dtype: &ArrowDataType, ) -> PolarsResult<()> { - if 
let ArrowDataType::Dictionary(key, value, _) = data_type.to_logical_type() { + if let ArrowDataType::Dictionary(key, value, _) = dtype.to_logical_type() { if *key != key_type { polars_bail!(ComputeError: "DictionaryArray must be initialized with a DataType::Dictionary whose integer is compatible to its keys") } - if value.as_ref().to_logical_type() != values_data_type.to_logical_type() { + if value.as_ref().to_logical_type() != values_dtype.to_logical_type() { polars_bail!(ComputeError: "DictionaryArray must be initialized with a DataType::Dictionary whose value is equal to its values") } } else { @@ -155,16 +155,16 @@ impl DictionaryArray { /// This function is `O(N)` where `N` is the length of keys /// # Errors /// This function errors iff - /// * the `data_type`'s logical type is not a `DictionaryArray` - /// * the `data_type`'s keys is not compatible with `keys` - /// * the `data_type`'s values's data_type is not equal with `values.data_type()` + /// * the `dtype`'s logical type is not a `DictionaryArray` + /// * the `dtype`'s keys is not compatible with `keys` + /// * the `dtype`'s values's dtype is not equal with `values.dtype()` /// * any of the keys's values is not represented in `usize` or is `>= values.len()` pub fn try_new( - data_type: ArrowDataType, + dtype: ArrowDataType, keys: PrimitiveArray, values: Box, ) -> PolarsResult { - check_data_type(K::KEY_TYPE, &data_type, values.data_type())?; + check_dtype(K::KEY_TYPE, &dtype, values.dtype())?; if keys.null_count() != keys.len() { if K::always_fits_usize() { @@ -177,7 +177,7 @@ impl DictionaryArray { } Ok(Self { - data_type, + dtype, keys, values, }) @@ -190,39 +190,39 @@ impl DictionaryArray { /// This function errors iff /// * any of the keys's values is not represented in `usize` or is `>= values.len()` pub fn try_from_keys(keys: PrimitiveArray, values: Box) -> PolarsResult { - let data_type = Self::default_data_type(values.data_type().clone()); - Self::try_new(data_type, keys, values) + let dtype = Self::default_dtype(values.dtype().clone()); + Self::try_new(dtype, keys, values) } /// Returns a new [`DictionaryArray`]. /// # Errors /// This function errors iff - /// * the `data_type`'s logical type is not a `DictionaryArray` - /// * the `data_type`'s keys is not compatible with `keys` - /// * the `data_type`'s values's data_type is not equal with `values.data_type()` + /// * the `dtype`'s logical type is not a `DictionaryArray` + /// * the `dtype`'s keys is not compatible with `keys` + /// * the `dtype`'s values's dtype is not equal with `values.dtype()` /// /// # Safety /// The caller must ensure that every keys's values is represented in `usize` and is `< values.len()` pub unsafe fn try_new_unchecked( - data_type: ArrowDataType, + dtype: ArrowDataType, keys: PrimitiveArray, values: Box, ) -> PolarsResult { - check_data_type(K::KEY_TYPE, &data_type, values.data_type())?; + check_dtype(K::KEY_TYPE, &dtype, values.dtype())?; Ok(Self { - data_type, + dtype, keys, values, }) } /// Returns a new empty [`DictionaryArray`]. 
- pub fn new_empty(data_type: ArrowDataType) -> Self { - let values = Self::try_get_child(&data_type).unwrap(); + pub fn new_empty(dtype: ArrowDataType) -> Self { + let values = Self::try_get_child(&dtype).unwrap(); let values = new_empty_array(values.clone()); Self::try_new( - data_type, + dtype, PrimitiveArray::::new_empty(K::PRIMITIVE.into()), values, ) @@ -231,11 +231,11 @@ impl DictionaryArray { /// Returns an [`DictionaryArray`] whose all elements are null #[inline] - pub fn new_null(data_type: ArrowDataType, length: usize) -> Self { - let values = Self::try_get_child(&data_type).unwrap(); + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { + let values = Self::try_get_child(&dtype).unwrap(); let values = new_null_array(values.clone(), 1); Self::try_new( - data_type, + dtype, PrimitiveArray::::new_null(K::PRIMITIVE.into(), length), values, ) @@ -282,20 +282,20 @@ impl DictionaryArray { /// Returns the [`ArrowDataType`] of this [`DictionaryArray`] #[inline] - pub fn data_type(&self) -> &ArrowDataType { - &self.data_type + pub fn dtype(&self) -> &ArrowDataType { + &self.dtype } /// Returns whether the values of this [`DictionaryArray`] are ordered #[inline] pub fn is_ordered(&self) -> bool { - match self.data_type.to_logical_type() { + match self.dtype.to_logical_type() { ArrowDataType::Dictionary(_, _, is_ordered) => *is_ordered, _ => unreachable!(), } } - pub(crate) fn default_data_type(values_datatype: ArrowDataType) -> ArrowDataType { + pub(crate) fn default_dtype(values_datatype: ArrowDataType) -> ArrowDataType { ArrowDataType::Dictionary(K::KEY_TYPE, Box::new(values_datatype), false) } @@ -395,8 +395,8 @@ impl DictionaryArray { new_scalar(self.values.as_ref(), index) } - pub(crate) fn try_get_child(data_type: &ArrowDataType) -> PolarsResult<&ArrowDataType> { - Ok(match data_type.to_logical_type() { + pub(crate) fn try_get_child(dtype: &ArrowDataType) -> PolarsResult<&ArrowDataType> { + Ok(match dtype.to_logical_type() { ArrowDataType::Dictionary(_, values, _) => values.as_ref(), _ => { polars_bail!(ComputeError: "Dictionaries must be initialized with DataType::Dictionary") @@ -428,12 +428,12 @@ impl Splitable for DictionaryArray { ( Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), keys: lhs_keys, values: self.values.clone(), }, Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), keys: rhs_keys, values: self.values.clone(), }, diff --git a/crates/polars-arrow/src/array/dictionary/mutable.rs b/crates/polars-arrow/src/array/dictionary/mutable.rs index d55ba6484443..1d01b1e719f9 100644 --- a/crates/polars-arrow/src/array/dictionary/mutable.rs +++ b/crates/polars-arrow/src/array/dictionary/mutable.rs @@ -13,7 +13,7 @@ use crate::datatypes::ArrowDataType; #[derive(Debug)] pub struct MutableDictionaryArray { - data_type: ArrowDataType, + dtype: ArrowDataType, map: ValueMap, // invariant: `max(keys) < map.values().len()` keys: MutablePrimitiveArray, @@ -24,7 +24,7 @@ impl From> for D // SAFETY: the invariant of this struct ensures that this is up-held unsafe { DictionaryArray::::try_new_unchecked( - other.data_type, + other.dtype, other.keys.into(), other.map.into_values().as_box(), ) @@ -69,10 +69,10 @@ impl MutableDictionaryArray { fn from_value_map(value_map: ValueMap) -> Self { let keys = MutablePrimitiveArray::::new(); - let data_type = - ArrowDataType::Dictionary(K::KEY_TYPE, Box::new(value_map.data_type().clone()), false); + let dtype = + ArrowDataType::Dictionary(K::KEY_TYPE, Box::new(value_map.dtype().clone()), false); 
Self { - data_type, + dtype, map: value_map, keys, } @@ -134,7 +134,7 @@ impl MutableDictionaryArray { fn take_into(&mut self) -> DictionaryArray { DictionaryArray::::try_new( - self.data_type.clone(), + self.dtype.clone(), std::mem::take(&mut self.keys).into(), self.map.take_into(), ) @@ -159,8 +159,8 @@ impl MutableArray for MutableDictio Arc::new(self.take_into()) } - fn data_type(&self) -> &ArrowDataType { - &self.data_type + fn dtype(&self) -> &ArrowDataType { + &self.dtype } fn as_any(&self) -> &dyn std::any::Any { diff --git a/crates/polars-arrow/src/array/dictionary/value_map.rs b/crates/polars-arrow/src/array/dictionary/value_map.rs index 5b6bdb9528ba..d818b7e6b25c 100644 --- a/crates/polars-arrow/src/array/dictionary/value_map.rs +++ b/crates/polars-arrow/src/array/dictionary/value_map.rs @@ -1,10 +1,11 @@ use std::borrow::Borrow; use std::fmt::{self, Debug}; -use std::hash::{BuildHasher, BuildHasherDefault, Hash, Hasher}; +use std::hash::{BuildHasherDefault, Hash, Hasher}; use hashbrown::hash_map::RawEntryMut; use hashbrown::HashMap; use polars_error::{polars_bail, polars_err, PolarsResult}; +use polars_utils::aliases::PlRandomState; use super::DictionaryKey; use crate::array::indexable::{AsIndexed, Indexable}; @@ -41,11 +42,6 @@ pub struct Hashed { key: K, } -#[inline] -fn ahash_hash(value: &T) -> u64 { - BuildHasherDefault::::default().hash_one(value) -} - impl Hash for Hashed { #[inline] fn hash(&self, state: &mut H) { @@ -57,6 +53,7 @@ impl Hash for Hashed { pub struct ValueMap { pub values: M, pub map: HashMap, (), BuildHasherDefault>, // NB: *only* use insert_hashed_nocheck() and no other hashmap API + random_state: PlRandomState, } impl ValueMap { @@ -67,6 +64,7 @@ impl ValueMap { Ok(Self { values, map: HashMap::default(), + random_state: PlRandomState::default(), }) } @@ -79,11 +77,12 @@ impl ValueMap { values.len(), BuildHasherDefault::::default(), ); + let random_state = PlRandomState::default(); for index in 0..values.len() { let key = K::try_from(index).map_err(|_| polars_err!(ComputeError: "overflow"))?; // SAFETY: we only iterate within bounds let value = unsafe { values.value_unchecked_at(index) }; - let hash = ahash_hash(value.borrow()); + let hash = random_state.hash_one(value.borrow()); let entry = map.raw_entry_mut().from_hash(hash, |item| { // SAFETY: invariant of the struct, it's always in bounds since we maintain it @@ -100,11 +99,15 @@ impl ValueMap { }, } } - Ok(Self { values, map }) + Ok(Self { + values, + map, + random_state, + }) } - pub fn data_type(&self) -> &ArrowDataType { - self.values.data_type() + pub fn dtype(&self) -> &ArrowDataType { + self.values.dtype() } pub fn into_values(self) -> M { @@ -133,7 +136,7 @@ impl ValueMap { V: AsIndexed, M::Type: Eq + Hash, { - let hash = ahash_hash(value.as_indexed()); + let hash = self.random_state.hash_one(value.as_indexed()); let entry = self.map.raw_entry_mut().from_hash(hash, |item| { // SAFETY: we've already checked (the inverse) when we pushed it, so it should be ok? 
let index = unsafe { item.key.as_usize() }; diff --git a/crates/polars-arrow/src/array/equal/binary.rs b/crates/polars-arrow/src/array/equal/binary.rs index bed8588efb59..93145aa461e2 100644 --- a/crates/polars-arrow/src/array/equal/binary.rs +++ b/crates/polars-arrow/src/array/equal/binary.rs @@ -2,5 +2,5 @@ use crate::array::BinaryArray; use crate::offset::Offset; pub(super) fn equal(lhs: &BinaryArray, rhs: &BinaryArray) -> bool { - lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) + lhs.dtype() == rhs.dtype() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) } diff --git a/crates/polars-arrow/src/array/equal/binary_view.rs b/crates/polars-arrow/src/array/equal/binary_view.rs index 546e3e2a1818..f413650dc9c3 100644 --- a/crates/polars-arrow/src/array/equal/binary_view.rs +++ b/crates/polars-arrow/src/array/equal/binary_view.rs @@ -5,5 +5,5 @@ pub(super) fn equal( lhs: &BinaryViewArrayGeneric, rhs: &BinaryViewArrayGeneric, ) -> bool { - lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) + lhs.dtype() == rhs.dtype() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) } diff --git a/crates/polars-arrow/src/array/equal/dictionary.rs b/crates/polars-arrow/src/array/equal/dictionary.rs index d65634095fb3..88213cbc059a 100644 --- a/crates/polars-arrow/src/array/equal/dictionary.rs +++ b/crates/polars-arrow/src/array/equal/dictionary.rs @@ -1,7 +1,7 @@ use crate::array::{DictionaryArray, DictionaryKey}; pub(super) fn equal(lhs: &DictionaryArray, rhs: &DictionaryArray) -> bool { - if !(lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len()) { + if !(lhs.dtype() == rhs.dtype() && lhs.len() == rhs.len()) { return false; }; diff --git a/crates/polars-arrow/src/array/equal/fixed_size_binary.rs b/crates/polars-arrow/src/array/equal/fixed_size_binary.rs index 883d5739778b..0e956e872090 100644 --- a/crates/polars-arrow/src/array/equal/fixed_size_binary.rs +++ b/crates/polars-arrow/src/array/equal/fixed_size_binary.rs @@ -1,5 +1,5 @@ use crate::array::{Array, FixedSizeBinaryArray}; pub(super) fn equal(lhs: &FixedSizeBinaryArray, rhs: &FixedSizeBinaryArray) -> bool { - lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) + lhs.dtype() == rhs.dtype() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) } diff --git a/crates/polars-arrow/src/array/equal/fixed_size_list.rs b/crates/polars-arrow/src/array/equal/fixed_size_list.rs index aaf77910013f..26582aa05379 100644 --- a/crates/polars-arrow/src/array/equal/fixed_size_list.rs +++ b/crates/polars-arrow/src/array/equal/fixed_size_list.rs @@ -1,5 +1,5 @@ use crate::array::{Array, FixedSizeListArray}; pub(super) fn equal(lhs: &FixedSizeListArray, rhs: &FixedSizeListArray) -> bool { - lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) + lhs.dtype() == rhs.dtype() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) } diff --git a/crates/polars-arrow/src/array/equal/list.rs b/crates/polars-arrow/src/array/equal/list.rs index 26faa1598faf..5c08e2103dcb 100644 --- a/crates/polars-arrow/src/array/equal/list.rs +++ b/crates/polars-arrow/src/array/equal/list.rs @@ -2,5 +2,5 @@ use crate::array::{Array, ListArray}; use crate::offset::Offset; pub(super) fn equal(lhs: &ListArray, rhs: &ListArray) -> bool { - lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) + lhs.dtype() == rhs.dtype() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) } diff --git 
a/crates/polars-arrow/src/array/equal/map.rs b/crates/polars-arrow/src/array/equal/map.rs index e150fb4a4b41..b98d65cea03a 100644 --- a/crates/polars-arrow/src/array/equal/map.rs +++ b/crates/polars-arrow/src/array/equal/map.rs @@ -1,5 +1,5 @@ use crate::array::{Array, MapArray}; pub(super) fn equal(lhs: &MapArray, rhs: &MapArray) -> bool { - lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) + lhs.dtype() == rhs.dtype() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) } diff --git a/crates/polars-arrow/src/array/equal/mod.rs b/crates/polars-arrow/src/array/equal/mod.rs index 0a929c793e13..971e4cbca4e8 100644 --- a/crates/polars-arrow/src/array/equal/mod.rs +++ b/crates/polars-arrow/src/array/equal/mod.rs @@ -201,12 +201,12 @@ impl PartialEq<&dyn Array> for MapArray { /// * their data types are equal /// * each of their items are equal pub fn equal(lhs: &dyn Array, rhs: &dyn Array) -> bool { - if lhs.data_type() != rhs.data_type() { + if lhs.dtype() != rhs.dtype() { return false; } use crate::datatypes::PhysicalType::*; - match lhs.data_type().to_physical_type() { + match lhs.dtype().to_physical_type() { Null => { let lhs = lhs.as_any().downcast_ref().unwrap(); let rhs = rhs.as_any().downcast_ref().unwrap(); diff --git a/crates/polars-arrow/src/array/equal/primitive.rs b/crates/polars-arrow/src/array/equal/primitive.rs index dc90bb15da5e..375335155dc8 100644 --- a/crates/polars-arrow/src/array/equal/primitive.rs +++ b/crates/polars-arrow/src/array/equal/primitive.rs @@ -2,5 +2,5 @@ use crate::array::PrimitiveArray; use crate::types::NativeType; pub(super) fn equal(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> bool { - lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) + lhs.dtype() == rhs.dtype() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) } diff --git a/crates/polars-arrow/src/array/equal/struct_.rs b/crates/polars-arrow/src/array/equal/struct_.rs index a1741e36368c..3e50626fe7d1 100644 --- a/crates/polars-arrow/src/array/equal/struct_.rs +++ b/crates/polars-arrow/src/array/equal/struct_.rs @@ -1,7 +1,7 @@ use crate::array::{Array, StructArray}; pub(super) fn equal(lhs: &StructArray, rhs: &StructArray) -> bool { - lhs.data_type() == rhs.data_type() + lhs.dtype() == rhs.dtype() && lhs.len() == rhs.len() && match (lhs.validity(), rhs.validity()) { (None, None) => lhs.values().iter().eq(rhs.values().iter()), diff --git a/crates/polars-arrow/src/array/equal/union.rs b/crates/polars-arrow/src/array/equal/union.rs index 51b9d960feac..94881c187fe9 100644 --- a/crates/polars-arrow/src/array/equal/union.rs +++ b/crates/polars-arrow/src/array/equal/union.rs @@ -1,5 +1,5 @@ use crate::array::{Array, UnionArray}; pub(super) fn equal(lhs: &UnionArray, rhs: &UnionArray) -> bool { - lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) + lhs.dtype() == rhs.dtype() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) } diff --git a/crates/polars-arrow/src/array/equal/utf8.rs b/crates/polars-arrow/src/array/equal/utf8.rs index 1327221ca331..f76d30a87368 100644 --- a/crates/polars-arrow/src/array/equal/utf8.rs +++ b/crates/polars-arrow/src/array/equal/utf8.rs @@ -2,5 +2,5 @@ use crate::array::Utf8Array; use crate::offset::Offset; pub(super) fn equal(lhs: &Utf8Array, rhs: &Utf8Array) -> bool { - lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) + lhs.dtype() == rhs.dtype() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) } diff 
--git a/crates/polars-arrow/src/array/ffi.rs b/crates/polars-arrow/src/array/ffi.rs index 9806eac25e97..bf9844529b1f 100644 --- a/crates/polars-arrow/src/array/ffi.rs +++ b/crates/polars-arrow/src/array/ffi.rs @@ -54,7 +54,7 @@ type BuffersChildren = ( pub fn offset_buffers_children_dictionary(array: &dyn Array) -> BuffersChildren { use PhysicalType::*; - match array.data_type().to_physical_type() { + match array.dtype().to_physical_type() { Null => ffi_dyn!(array, NullArray), Boolean => ffi_dyn!(array, BooleanArray), Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { diff --git a/crates/polars-arrow/src/array/fixed_size_binary/data.rs b/crates/polars-arrow/src/array/fixed_size_binary/data.rs index f99822eb0fbb..f04be9883f64 100644 --- a/crates/polars-arrow/src/array/fixed_size_binary/data.rs +++ b/crates/polars-arrow/src/array/fixed_size_binary/data.rs @@ -7,8 +7,8 @@ use crate::datatypes::ArrowDataType; impl Arrow2Arrow for FixedSizeBinaryArray { fn to_data(&self) -> ArrayData { - let data_type = self.data_type.clone().into(); - let builder = ArrayDataBuilder::new(data_type) + let dtype = self.dtype.clone().into(); + let builder = ArrayDataBuilder::new(dtype) .len(self.len()) .buffers(vec![self.values.clone().into()]) .nulls(self.validity.as_ref().map(|b| b.clone().into())); @@ -18,8 +18,8 @@ impl Arrow2Arrow for FixedSizeBinaryArray { } fn from_data(data: &ArrayData) -> Self { - let data_type: ArrowDataType = data.data_type().clone().into(); - let size = match data_type { + let dtype: ArrowDataType = data.data_type().clone().into(); + let size = match dtype { ArrowDataType::FixedSizeBinary(size) => size, _ => unreachable!("must be FixedSizeBinary"), }; @@ -29,7 +29,7 @@ impl Arrow2Arrow for FixedSizeBinaryArray { Self { size, - data_type, + dtype, values, validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), } diff --git a/crates/polars-arrow/src/array/fixed_size_binary/ffi.rs b/crates/polars-arrow/src/array/fixed_size_binary/ffi.rs index 43af7fef58ad..d3d0c777dd66 100644 --- a/crates/polars-arrow/src/array/fixed_size_binary/ffi.rs +++ b/crates/polars-arrow/src/array/fixed_size_binary/ffi.rs @@ -39,7 +39,7 @@ unsafe impl ToFfi for FixedSizeBinaryArray { Self { size: self.size, - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), validity, values: self.values.clone(), } @@ -48,10 +48,10 @@ unsafe impl ToFfi for FixedSizeBinaryArray { impl FromFfi for FixedSizeBinaryArray { unsafe fn try_from_ffi(array: A) -> PolarsResult { - let data_type = array.data_type().clone(); + let dtype = array.dtype().clone(); let validity = unsafe { array.validity() }?; let values = unsafe { array.buffer::(1) }?; - Self::try_new(data_type, values, validity) + Self::try_new(dtype, values, validity) } } diff --git a/crates/polars-arrow/src/array/fixed_size_binary/fmt.rs b/crates/polars-arrow/src/array/fixed_size_binary/fmt.rs index c5f9e2dd3293..6aa47acf3fd8 100644 --- a/crates/polars-arrow/src/array/fixed_size_binary/fmt.rs +++ b/crates/polars-arrow/src/array/fixed_size_binary/fmt.rs @@ -14,7 +14,7 @@ impl Debug for FixedSizeBinaryArray { fn fmt(&self, f: &mut Formatter<'_>) -> Result { let writer = |f: &mut Formatter, index| write_value(self, index, f); - write!(f, "{:?}", self.data_type)?; + write!(f, "{:?}", self.dtype)?; write_vec(f, writer, self.validity(), self.len(), "None", false) } } diff --git a/crates/polars-arrow/src/array/fixed_size_binary/mod.rs b/crates/polars-arrow/src/array/fixed_size_binary/mod.rs index 1194b8f5044d..ec3f96626c14 100644 --- 
a/crates/polars-arrow/src/array/fixed_size_binary/mod.rs +++ b/crates/polars-arrow/src/array/fixed_size_binary/mod.rs @@ -16,8 +16,8 @@ use polars_error::{polars_bail, polars_ensure, PolarsResult}; /// Cloning and slicing this struct is `O(1)`. #[derive(Clone)] pub struct FixedSizeBinaryArray { - size: usize, // this is redundant with `data_type`, but useful to not have to deconstruct the data_type. - data_type: ArrowDataType, + size: usize, // this is redundant with `dtype`, but useful to not have to deconstruct the dtype. + dtype: ArrowDataType, values: Buffer, validity: Option, } @@ -27,15 +27,15 @@ impl FixedSizeBinaryArray { /// /// # Errors /// This function returns an error iff: - /// * The `data_type`'s physical type is not [`crate::datatypes::PhysicalType::FixedSizeBinary`] - /// * The length of `values` is not a multiple of `size` in `data_type` + /// * The `dtype`'s physical type is not [`crate::datatypes::PhysicalType::FixedSizeBinary`] + /// * The length of `values` is not a multiple of `size` in `dtype` /// * the validity's length is not equal to `values.len() / size`. pub fn try_new( - data_type: ArrowDataType, + dtype: ArrowDataType, values: Buffer, validity: Option, ) -> PolarsResult { - let size = Self::maybe_get_size(&data_type)?; + let size = Self::maybe_get_size(&dtype)?; if values.len() % size != 0 { polars_bail!(ComputeError: @@ -55,7 +55,7 @@ impl FixedSizeBinaryArray { Ok(Self { size, - data_type, + dtype, values, validity, }) @@ -64,23 +64,23 @@ impl FixedSizeBinaryArray { /// Creates a new [`FixedSizeBinaryArray`]. /// # Panics /// This function panics iff: - /// * The `data_type`'s physical type is not [`crate::datatypes::PhysicalType::FixedSizeBinary`] - /// * The length of `values` is not a multiple of `size` in `data_type` + /// * The `dtype`'s physical type is not [`crate::datatypes::PhysicalType::FixedSizeBinary`] + /// * The length of `values` is not a multiple of `size` in `dtype` /// * the validity's length is not equal to `values.len() / size`. - pub fn new(data_type: ArrowDataType, values: Buffer, validity: Option) -> Self { - Self::try_new(data_type, values, validity).unwrap() + pub fn new(dtype: ArrowDataType, values: Buffer, validity: Option) -> Self { + Self::try_new(dtype, values, validity).unwrap() } /// Returns a new empty [`FixedSizeBinaryArray`]. - pub fn new_empty(data_type: ArrowDataType) -> Self { - Self::new(data_type, Buffer::new(), None) + pub fn new_empty(dtype: ArrowDataType) -> Self { + Self::new(dtype, Buffer::new(), None) } /// Returns a new null [`FixedSizeBinaryArray`]. - pub fn new_null(data_type: ArrowDataType, length: usize) -> Self { - let size = Self::maybe_get_size(&data_type).unwrap(); + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { + let size = Self::maybe_get_size(&dtype).unwrap(); Self::new( - data_type, + dtype, vec![0u8; length * size].into(), Some(Bitmap::new_zeroed(length)), ) @@ -178,13 +178,10 @@ impl FixedSizeBinaryArray { /// Returns a new [`FixedSizeBinaryArray`] with a different logical type. /// This is `O(1)`. /// # Panics - /// Panics iff the data_type is not supported for the physical type. + /// Panics iff the dtype is not supported for the physical type. 
#[inline] - pub fn to(self, data_type: ArrowDataType) -> Self { - match ( - data_type.to_logical_type(), - self.data_type().to_logical_type(), - ) { + pub fn to(self, dtype: ArrowDataType) -> Self { + match (dtype.to_logical_type(), self.dtype().to_logical_type()) { (ArrowDataType::FixedSizeBinary(size_a), ArrowDataType::FixedSizeBinary(size_b)) if size_a == size_b => {}, _ => panic!("Wrong DataType"), @@ -192,7 +189,7 @@ impl FixedSizeBinaryArray { Self { size: self.size, - data_type, + dtype, values: self.values, validity: self.validity, } @@ -205,8 +202,8 @@ impl FixedSizeBinaryArray { } impl FixedSizeBinaryArray { - pub(crate) fn maybe_get_size(data_type: &ArrowDataType) -> PolarsResult { - match data_type.to_logical_type() { + pub(crate) fn maybe_get_size(dtype: &ArrowDataType) -> PolarsResult { + match dtype.to_logical_type() { ArrowDataType::FixedSizeBinary(size) => { polars_ensure!(*size != 0, ComputeError: "FixedSizeBinaryArray expects a positive size"); Ok(*size) @@ -217,8 +214,8 @@ impl FixedSizeBinaryArray { } } - pub fn get_size(data_type: &ArrowDataType) -> usize { - Self::maybe_get_size(data_type).unwrap() + pub fn get_size(dtype: &ArrowDataType) -> usize { + Self::maybe_get_size(dtype).unwrap() } } @@ -248,13 +245,13 @@ impl Splitable for FixedSizeBinaryArray { ( Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), values: lhs_values, validity: lhs_validity, size, }, Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), values: rhs_values, validity: rhs_validity, size, diff --git a/crates/polars-arrow/src/array/fixed_size_binary/mutable.rs b/crates/polars-arrow/src/array/fixed_size_binary/mutable.rs index 8f81ce86f6d8..903c33178640 100644 --- a/crates/polars-arrow/src/array/fixed_size_binary/mutable.rs +++ b/crates/polars-arrow/src/array/fixed_size_binary/mutable.rs @@ -14,7 +14,7 @@ use crate::datatypes::ArrowDataType; /// This struct does not allocate a validity until one is required (i.e. push a null to it). #[derive(Debug, Clone)] pub struct MutableFixedSizeBinaryArray { - data_type: ArrowDataType, + dtype: ArrowDataType, size: usize, values: Vec, validity: Option, @@ -23,7 +23,7 @@ pub struct MutableFixedSizeBinaryArray { impl From for FixedSizeBinaryArray { fn from(other: MutableFixedSizeBinaryArray) -> Self { FixedSizeBinaryArray::new( - other.data_type, + other.dtype, other.values.into(), other.validity.map(|x| x.into()), ) @@ -35,15 +35,15 @@ impl MutableFixedSizeBinaryArray { /// /// # Errors /// This function returns an error iff: - /// * The `data_type`'s physical type is not [`crate::datatypes::PhysicalType::FixedSizeBinary`] - /// * The length of `values` is not a multiple of `size` in `data_type` + /// * The `dtype`'s physical type is not [`crate::datatypes::PhysicalType::FixedSizeBinary`] + /// * The length of `values` is not a multiple of `size` in `dtype` /// * the validity's length is not equal to `values.len() / size`. 
pub fn try_new( - data_type: ArrowDataType, + dtype: ArrowDataType, values: Vec, validity: Option, ) -> PolarsResult { - let size = FixedSizeBinaryArray::maybe_get_size(&data_type)?; + let size = FixedSizeBinaryArray::maybe_get_size(&dtype)?; if values.len() % size != 0 { polars_bail!(ComputeError: @@ -63,7 +63,7 @@ impl MutableFixedSizeBinaryArray { Ok(Self { size, - data_type, + dtype, values, validity, }) @@ -114,9 +114,8 @@ impl MutableFixedSizeBinaryArray { } self.values.extend_from_slice(bytes); - match &mut self.validity { - Some(validity) => validity.push(true), - None => {}, + if let Some(validity) = &mut self.validity { + validity.push(true) } }, None => { @@ -265,8 +264,8 @@ impl MutableArray for MutableFixedSizeBinaryArray { .arced() } - fn data_type(&self) -> &ArrowDataType { - &self.data_type + fn dtype(&self) -> &ArrowDataType { + &self.dtype } fn as_any(&self) -> &dyn std::any::Any { diff --git a/crates/polars-arrow/src/array/fixed_size_list/data.rs b/crates/polars-arrow/src/array/fixed_size_list/data.rs index f98fa452c6ea..de9bc1b882c2 100644 --- a/crates/polars-arrow/src/array/fixed_size_list/data.rs +++ b/crates/polars-arrow/src/array/fixed_size_list/data.rs @@ -6,8 +6,8 @@ use crate::datatypes::ArrowDataType; impl Arrow2Arrow for FixedSizeListArray { fn to_data(&self) -> ArrayData { - let data_type = self.data_type.clone().into(); - let builder = ArrayDataBuilder::new(data_type) + let dtype = self.dtype.clone().into(); + let builder = ArrayDataBuilder::new(dtype) .len(self.len()) .nulls(self.validity.as_ref().map(|b| b.clone().into())) .child_data(vec![to_data(self.values.as_ref())]); @@ -17,8 +17,8 @@ impl Arrow2Arrow for FixedSizeListArray { } fn from_data(data: &ArrayData) -> Self { - let data_type: ArrowDataType = data.data_type().clone().into(); - let size = match data_type { + let dtype: ArrowDataType = data.data_type().clone().into(); + let size = match dtype { ArrowDataType::FixedSizeList(_, size) => size, _ => unreachable!("must be FixedSizeList type"), }; @@ -28,7 +28,7 @@ impl Arrow2Arrow for FixedSizeListArray { Self { size, - data_type, + dtype, values, validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), } diff --git a/crates/polars-arrow/src/array/fixed_size_list/ffi.rs b/crates/polars-arrow/src/array/fixed_size_list/ffi.rs index 7cb463974e29..29cf7957cf6c 100644 --- a/crates/polars-arrow/src/array/fixed_size_list/ffi.rs +++ b/crates/polars-arrow/src/array/fixed_size_list/ffi.rs @@ -30,12 +30,12 @@ unsafe impl ToFfi for FixedSizeListArray { impl FromFfi for FixedSizeListArray { unsafe fn try_from_ffi(array: A) -> PolarsResult { - let data_type = array.data_type().clone(); + let dtype = array.dtype().clone(); let validity = unsafe { array.validity() }?; let child = unsafe { array.child(0)? }; let values = ffi::try_from(child)?; - let mut fsl = Self::try_new(data_type, values, validity)?; + let mut fsl = Self::try_new(dtype, values, validity)?; fsl.slice(array.offset(), array.length()); Ok(fsl) } diff --git a/crates/polars-arrow/src/array/fixed_size_list/mod.rs b/crates/polars-arrow/src/array/fixed_size_list/mod.rs index 7e512cba5203..3b8b9890d4f4 100644 --- a/crates/polars-arrow/src/array/fixed_size_list/mod.rs +++ b/crates/polars-arrow/src/array/fixed_size_list/mod.rs @@ -11,13 +11,14 @@ mod iterator; mod mutable; pub use mutable::*; use polars_error::{polars_bail, PolarsResult}; +use polars_utils::pl_str::PlSmallStr; /// The Arrow's equivalent to an immutable `Vec>` where `T` is an Arrow type. 
/// Cloning and slicing this struct is `O(1)`. #[derive(Clone)] pub struct FixedSizeListArray { - size: usize, // this is redundant with `data_type`, but useful to not have to deconstruct the data_type. - data_type: ArrowDataType, + size: usize, // this is redundant with `dtype`, but useful to not have to deconstruct the dtype. + dtype: ArrowDataType, values: Box, validity: Option, } @@ -27,21 +28,21 @@ impl FixedSizeListArray { /// /// # Errors /// This function returns an error iff: - /// * The `data_type`'s physical type is not [`crate::datatypes::PhysicalType::FixedSizeList`] - /// * The `data_type`'s inner field's data type is not equal to `values.data_type`. - /// * The length of `values` is not a multiple of `size` in `data_type` + /// * The `dtype`'s physical type is not [`crate::datatypes::PhysicalType::FixedSizeList`] + /// * The `dtype`'s inner field's data type is not equal to `values.dtype`. + /// * The length of `values` is not a multiple of `size` in `dtype` /// * the validity's length is not equal to `values.len() / size`. pub fn try_new( - data_type: ArrowDataType, + dtype: ArrowDataType, values: Box, validity: Option, ) -> PolarsResult { - let (child, size) = Self::try_child_and_size(&data_type)?; + let (child, size) = Self::try_child_and_size(&dtype)?; - let child_data_type = &child.data_type; - let values_data_type = values.data_type(); - if child_data_type != values_data_type { - polars_bail!(ComputeError: "FixedSizeListArray's child's DataType must match. However, the expected DataType is {child_data_type:?} while it got {values_data_type:?}.") + let child_dtype = &child.dtype; + let values_dtype = values.dtype(); + if child_dtype != values_dtype { + polars_bail!(ComputeError: "FixedSizeListArray's child's DataType must match. However, the expected DataType is {child_dtype:?} while it got {values_dtype:?}.") } if values.len() % size != 0 { @@ -62,7 +63,7 @@ impl FixedSizeListArray { Ok(Self { size, - data_type, + dtype, values, validity, }) @@ -70,8 +71,8 @@ impl FixedSizeListArray { /// Alias to `Self::try_new(...).unwrap()` #[track_caller] - pub fn new(data_type: ArrowDataType, values: Box, validity: Option) -> Self { - Self::try_new(data_type, values, validity).unwrap() + pub fn new(dtype: ArrowDataType, values: Box, validity: Option) -> Self { + Self::try_new(dtype, values, validity).unwrap() } /// Returns the size (number of elements per slot) of this [`FixedSizeListArray`]. @@ -80,17 +81,17 @@ impl FixedSizeListArray { } /// Returns a new empty [`FixedSizeListArray`]. - pub fn new_empty(data_type: ArrowDataType) -> Self { - let values = new_empty_array(Self::get_child_and_size(&data_type).0.data_type().clone()); - Self::new(data_type, values, None) + pub fn new_empty(dtype: ArrowDataType) -> Self { + let values = new_empty_array(Self::get_child_and_size(&dtype).0.dtype().clone()); + Self::new(dtype, values, None) } /// Returns a new null [`FixedSizeListArray`]. 
- pub fn new_null(data_type: ArrowDataType, length: usize) -> Self { - let (field, size) = Self::get_child_and_size(&data_type); + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { + let (field, size) = Self::get_child_and_size(&dtype); - let values = new_null_array(field.data_type().clone(), length * size); - Self::new(data_type, values, Some(Bitmap::new_zeroed(length))) + let values = new_null_array(field.dtype().clone(), length * size); + Self::new(dtype, values, Some(Bitmap::new_zeroed(length))) } } @@ -181,8 +182,8 @@ impl FixedSizeListArray { } impl FixedSizeListArray { - pub(crate) fn try_child_and_size(data_type: &ArrowDataType) -> PolarsResult<(&Field, usize)> { - match data_type.to_logical_type() { + pub(crate) fn try_child_and_size(dtype: &ArrowDataType) -> PolarsResult<(&Field, usize)> { + match dtype.to_logical_type() { ArrowDataType::FixedSizeList(child, size) => { if *size == 0 { polars_bail!(ComputeError: "FixedSizeBinaryArray expects a positive size") @@ -193,13 +194,13 @@ impl FixedSizeListArray { } } - pub(crate) fn get_child_and_size(data_type: &ArrowDataType) -> (&Field, usize) { - Self::try_child_and_size(data_type).unwrap() + pub(crate) fn get_child_and_size(dtype: &ArrowDataType) -> (&Field, usize) { + Self::try_child_and_size(dtype).unwrap() } /// Returns a [`ArrowDataType`] consistent with [`FixedSizeListArray`]. - pub fn default_datatype(data_type: ArrowDataType, size: usize) -> ArrowDataType { - let field = Box::new(Field::new("item", data_type, true)); + pub fn default_datatype(dtype: ArrowDataType, size: usize) -> ArrowDataType { + let field = Box::new(Field::new(PlSmallStr::from_static("item"), dtype, true)); ArrowDataType::FixedSizeList(field, size) } } @@ -232,13 +233,13 @@ impl Splitable for FixedSizeListArray { ( Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), values: lhs_values, validity: lhs_validity, size, }, Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), values: rhs_values, validity: rhs_validity, size, diff --git a/crates/polars-arrow/src/array/fixed_size_list/mutable.rs b/crates/polars-arrow/src/array/fixed_size_list/mutable.rs index ddd03b9ea099..04802e59bd67 100644 --- a/crates/polars-arrow/src/array/fixed_size_list/mutable.rs +++ b/crates/polars-arrow/src/array/fixed_size_list/mutable.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use polars_error::{polars_bail, PolarsResult}; +use polars_utils::pl_str::PlSmallStr; use super::FixedSizeListArray; use crate::array::physical_binary::extend_validity; @@ -11,7 +12,7 @@ use crate::datatypes::{ArrowDataType, Field}; /// The mutable version of [`FixedSizeListArray`]. #[derive(Debug, Clone)] pub struct MutableFixedSizeListArray { - data_type: ArrowDataType, + dtype: ArrowDataType, size: usize, values: M, validity: Option, @@ -20,7 +21,7 @@ pub struct MutableFixedSizeListArray { impl From> for FixedSizeListArray { fn from(mut other: MutableFixedSizeListArray) -> Self { FixedSizeListArray::new( - other.data_type, + other.dtype, other.values.as_box(), other.validity.map(|x| x.into()), ) @@ -30,29 +31,29 @@ impl From> for FixedSizeListArray impl MutableFixedSizeListArray { /// Creates a new [`MutableFixedSizeListArray`] from a [`MutableArray`] and size. 
pub fn new(values: M, size: usize) -> Self { - let data_type = FixedSizeListArray::default_datatype(values.data_type().clone(), size); - Self::new_from(values, data_type, size) + let dtype = FixedSizeListArray::default_datatype(values.dtype().clone(), size); + Self::new_from(values, dtype, size) } /// Creates a new [`MutableFixedSizeListArray`] from a [`MutableArray`] and size. - pub fn new_with_field(values: M, name: &str, nullable: bool, size: usize) -> Self { - let data_type = ArrowDataType::FixedSizeList( - Box::new(Field::new(name, values.data_type().clone(), nullable)), + pub fn new_with_field(values: M, name: PlSmallStr, nullable: bool, size: usize) -> Self { + let dtype = ArrowDataType::FixedSizeList( + Box::new(Field::new(name, values.dtype().clone(), nullable)), size, ); - Self::new_from(values, data_type, size) + Self::new_from(values, dtype, size) } /// Creates a new [`MutableFixedSizeListArray`] from a [`MutableArray`], [`ArrowDataType`] and size. - pub fn new_from(values: M, data_type: ArrowDataType, size: usize) -> Self { + pub fn new_from(values: M, dtype: ArrowDataType, size: usize) -> Self { assert_eq!(values.len(), 0); - match data_type { + match dtype { ArrowDataType::FixedSizeList(..) => (), - _ => panic!("data type must be FixedSizeList (got {data_type:?})"), + _ => panic!("data type must be FixedSizeList (got {dtype:?})"), }; Self { size, - data_type, + dtype, values, validity: None, } @@ -146,7 +147,7 @@ impl MutableArray for MutableFixedSizeListArray { fn as_box(&mut self) -> Box { FixedSizeListArray::new( - self.data_type.clone(), + self.dtype.clone(), self.values.as_box(), std::mem::take(&mut self.validity).map(|x| x.into()), ) @@ -155,15 +156,15 @@ impl MutableArray for MutableFixedSizeListArray { fn as_arc(&mut self) -> Arc { FixedSizeListArray::new( - self.data_type.clone(), + self.dtype.clone(), self.values.as_box(), std::mem::take(&mut self.validity).map(|x| x.into()), ) .arced() } - fn data_type(&self) -> &ArrowDataType { - &self.data_type + fn dtype(&self) -> &ArrowDataType { + &self.dtype } fn as_any(&self) -> &dyn std::any::Any { diff --git a/crates/polars-arrow/src/array/fmt.rs b/crates/polars-arrow/src/array/fmt.rs index 2def0374ab19..6b3fc21752b1 100644 --- a/crates/polars-arrow/src/array/fmt.rs +++ b/crates/polars-arrow/src/array/fmt.rs @@ -12,7 +12,7 @@ pub fn get_value_display<'a, F: Write + 'a>( null: &'static str, ) -> Box Result + 'a> { use crate::datatypes::PhysicalType::*; - match array.data_type().to_physical_type() { + match array.dtype().to_physical_type() { Null => Box::new(move |f, _| write!(f, "{null}")), Boolean => Box::new(|f, index| { super::boolean::fmt::write_value(array.as_any().downcast_ref().unwrap(), index, f) diff --git a/crates/polars-arrow/src/array/growable/binary.rs b/crates/polars-arrow/src/array/growable/binary.rs index f0b746de2535..44b6ec4da147 100644 --- a/crates/polars-arrow/src/array/growable/binary.rs +++ b/crates/polars-arrow/src/array/growable/binary.rs @@ -13,7 +13,7 @@ use crate::offset::{Offset, Offsets}; /// Concrete [`Growable`] for the [`BinaryArray`]. pub struct GrowableBinary<'a, O: Offset> { arrays: Vec<&'a BinaryArray>, - data_type: ArrowDataType, + dtype: ArrowDataType, validity: Option, values: Vec, offsets: Offsets, @@ -24,7 +24,7 @@ impl<'a, O: Offset> GrowableBinary<'a, O> { /// # Panics /// If `arrays` is empty. 
pub fn new(arrays: Vec<&'a BinaryArray>, mut use_validity: bool, capacity: usize) -> Self { - let data_type = arrays[0].data_type().clone(); + let dtype = arrays[0].dtype().clone(); // if any of the arrays has nulls, insertions from any array requires setting bits // as there is at least one array with nulls. @@ -34,7 +34,7 @@ impl<'a, O: Offset> GrowableBinary<'a, O> { Self { arrays, - data_type, + dtype, values: Vec::with_capacity(0), offsets: Offsets::with_capacity(capacity), validity: prepare_validity(use_validity, capacity), @@ -42,13 +42,13 @@ impl<'a, O: Offset> GrowableBinary<'a, O> { } fn to(&mut self) -> BinaryArray { - let data_type = self.data_type.clone(); + let dtype = self.dtype.clone(); let validity = std::mem::take(&mut self.validity); let offsets = std::mem::take(&mut self.offsets); let values = std::mem::take(&mut self.values); BinaryArray::::new( - data_type, + dtype, offsets.into(), values.into(), validity.map(|v| v.into()), @@ -96,7 +96,7 @@ impl<'a, O: Offset> Growable<'a> for GrowableBinary<'a, O> { impl<'a, O: Offset> From> for BinaryArray { fn from(val: GrowableBinary<'a, O>) -> Self { BinaryArray::::new( - val.data_type, + val.dtype, val.offsets.into(), val.values.into(), val.validity.map(|v| v.into()), diff --git a/crates/polars-arrow/src/array/growable/binview.rs b/crates/polars-arrow/src/array/growable/binview.rs index 9e4c871d596f..6c974510fc46 100644 --- a/crates/polars-arrow/src/array/growable/binview.rs +++ b/crates/polars-arrow/src/array/growable/binview.rs @@ -1,6 +1,9 @@ use std::ops::Deref; use std::sync::Arc; +use polars_utils::aliases::{InitHashMaps, PlHashSet}; +use polars_utils::itertools::Itertools; + use super::Growable; use crate::array::binview::{BinaryViewArrayGeneric, ViewType}; use crate::array::growable::utils::{extend_validity, extend_validity_copies, prepare_validity}; @@ -8,16 +11,16 @@ use crate::array::{Array, MutableBinaryViewArray, View}; use crate::bitmap::{Bitmap, MutableBitmap}; use crate::buffer::Buffer; use crate::datatypes::ArrowDataType; -use crate::legacy::utils::CustomIterTools; /// Concrete [`Growable`] for the [`BinaryArray`]. pub struct GrowableBinaryViewArray<'a, T: ViewType + ?Sized> { arrays: Vec<&'a BinaryViewArrayGeneric>, - data_type: ArrowDataType, + dtype: ArrowDataType, validity: Option, inner: MutableBinaryViewArray, same_buffers: Option<&'a Arc<[Buffer]>>, total_same_buffers_len: usize, // Only valid if same_buffers is Some. + has_duplicate_buffers: bool, } impl<'a, T: ViewType + ?Sized> GrowableBinaryViewArray<'a, T> { @@ -29,7 +32,7 @@ impl<'a, T: ViewType + ?Sized> GrowableBinaryViewArray<'a, T> { mut use_validity: bool, capacity: usize, ) -> Self { - let data_type = arrays[0].data_type().clone(); + let dtype = arrays[0].dtype().clone(); // if any of the arrays has nulls, insertions from any array requires setting bits // as there is at least one array with nulls. 
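Note on the recurring "use_validity" comment above: each growable decides up front whether a validity bitmap is needed, because as soon as any source array carries nulls, every later insertion must set bits, even when copying from null-free arrays. A minimal standalone sketch of that idea, assuming the public polars_arrow API; the helper name is illustrative and not the crate-internal `prepare_validity`:

    use polars_arrow::array::Array;
    use polars_arrow::bitmap::MutableBitmap;

    // Sketch: allocate a validity bitmap up front iff any source array has nulls.
    fn prepare_validity_sketch(
        arrays: &[&dyn Array],
        mut use_validity: bool,
        capacity: usize,
    ) -> Option<MutableBitmap> {
        // One array with nulls forces validity tracking for all insertions.
        if arrays.iter().any(|a| a.null_count() > 0) {
            use_validity = true;
        }
        use_validity.then(|| MutableBitmap::with_capacity(capacity))
    }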
@@ -51,13 +54,22 @@ impl<'a, T: ViewType + ?Sized> GrowableBinaryViewArray<'a, T> { .then(|| arrays[0].total_buffer_len()) .unwrap_or_default(); + let mut duplicates = PlHashSet::new(); + let mut has_duplicate_buffers = false; + for arr in arrays.iter() { + if !duplicates.insert(arr.data_buffers().as_ptr()) { + has_duplicate_buffers = true; + break; + } + } Self { arrays, - data_type, + dtype, validity: prepare_validity(use_validity, capacity), inner: MutableBinaryViewArray::::with_capacity(capacity), same_buffers, total_same_buffers_len, + has_duplicate_buffers, } } @@ -66,7 +78,7 @@ impl<'a, T: ViewType + ?Sized> GrowableBinaryViewArray<'a, T> { if let Some(buffers) = self.same_buffers { unsafe { BinaryViewArrayGeneric::::new_unchecked( - self.data_type.clone(), + self.dtype.clone(), arr.views.into(), buffers.clone(), self.validity.take().map(Bitmap::from), @@ -75,7 +87,7 @@ impl<'a, T: ViewType + ?Sized> GrowableBinaryViewArray<'a, T> { ) } } else { - arr.freeze_with_dtype(self.data_type.clone()) + arr.freeze_with_dtype(self.dtype.clone()) .with_validity(self.validity.take().map(Bitmap::from)) } } @@ -91,15 +103,19 @@ impl<'a, T: ViewType + ?Sized> Growable<'a> for GrowableBinaryViewArray<'a, T> { let range = start..start + len; let views_iter = array.views().get_unchecked(range).iter().cloned(); + if self.same_buffers.is_some() { let mut total_len = 0; self.inner .views .extend(views_iter.inspect(|v| total_len += v.length as usize)); self.inner.total_bytes_len += total_len; + } else if self.has_duplicate_buffers { + self.inner + .extend_non_null_views_unchecked_dedupe(views_iter, local_buffers.deref()); } else { self.inner - .extend_non_null_views_trusted_len_unchecked(views_iter, local_buffers.deref()); + .extend_non_null_views_unchecked(views_iter, local_buffers.deref()); } } diff --git a/crates/polars-arrow/src/array/growable/boolean.rs b/crates/polars-arrow/src/array/growable/boolean.rs index ea18791a804d..622d26493247 100644 --- a/crates/polars-arrow/src/array/growable/boolean.rs +++ b/crates/polars-arrow/src/array/growable/boolean.rs @@ -11,7 +11,7 @@ use crate::datatypes::ArrowDataType; /// Concrete [`Growable`] for the [`BooleanArray`]. pub struct GrowableBoolean<'a> { arrays: Vec<&'a BooleanArray>, - data_type: ArrowDataType, + dtype: ArrowDataType, validity: Option, values: MutableBitmap, } @@ -21,7 +21,7 @@ impl<'a> GrowableBoolean<'a> { /// # Panics /// If `arrays` is empty. pub fn new(arrays: Vec<&'a BooleanArray>, mut use_validity: bool, capacity: usize) -> Self { - let data_type = arrays[0].data_type().clone(); + let dtype = arrays[0].dtype().clone(); // if any of the arrays has nulls, insertions from any array requires setting bits // as there is at least one array with nulls. 
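The `has_duplicate_buffers` flag added to `GrowableBinaryViewArray::new` above is computed by inserting each source array's data-buffer pointer into a hash set; a failed insert means two arrays share a buffer allocation, so the slower dedupe extend path must be used. A rough sketch of that pointer check, using std's HashSet instead of `PlHashSet` purely for illustration:

    use std::collections::HashSet;

    // Sketch: detect whether any two source arrays reference the same buffer allocation.
    fn has_duplicate_buffers<T>(buffer_ptrs: &[*const T]) -> bool {
        let mut seen = HashSet::new();
        // `insert` returns false when the pointer was already present, i.e. a duplicate.
        buffer_ptrs.iter().any(|p| !seen.insert(*p))
    }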
@@ -31,7 +31,7 @@ impl<'a> GrowableBoolean<'a> { Self { arrays, - data_type, + dtype, values: MutableBitmap::with_capacity(capacity), validity: prepare_validity(use_validity, capacity), } @@ -42,7 +42,7 @@ impl<'a> GrowableBoolean<'a> { let values = std::mem::take(&mut self.values); BooleanArray::new( - self.data_type.clone(), + self.dtype.clone(), values.into(), validity.map(|v| v.into()), ) @@ -87,10 +87,6 @@ impl<'a> Growable<'a> for GrowableBoolean<'a> { impl<'a> From> for BooleanArray { fn from(val: GrowableBoolean<'a>) -> Self { - BooleanArray::new( - val.data_type, - val.values.into(), - val.validity.map(|v| v.into()), - ) + BooleanArray::new(val.dtype, val.values.into(), val.validity.map(|v| v.into())) } } diff --git a/crates/polars-arrow/src/array/growable/dictionary.rs b/crates/polars-arrow/src/array/growable/dictionary.rs index dd2dbc01fde4..f11726be5265 100644 --- a/crates/polars-arrow/src/array/growable/dictionary.rs +++ b/crates/polars-arrow/src/array/growable/dictionary.rs @@ -13,7 +13,7 @@ use crate::datatypes::ArrowDataType; /// This growable does not perform collision checks and instead concatenates /// the values of each [`DictionaryArray`] one after the other. pub struct GrowableDictionary<'a, K: DictionaryKey> { - data_type: ArrowDataType, + dtype: ArrowDataType, keys: Vec<&'a PrimitiveArray>, key_values: Vec, validity: Option, @@ -41,7 +41,7 @@ impl<'a, T: DictionaryKey> GrowableDictionary<'a, T> { /// # Panics /// If `arrays` is empty. pub fn new(arrays: &[&'a DictionaryArray], mut use_validity: bool, capacity: usize) -> Self { - let data_type = arrays[0].data_type().clone(); + let dtype = arrays[0].dtype().clone(); // if any of the arrays has nulls, insertions from any array requires setting bits // as there is at least one array with nulls. 
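The `GrowableDictionary` doc above notes that no collision checks are performed: the values of every input dictionary are concatenated back to back, which implies that keys taken from the i-th array must be shifted by the combined values length of all earlier arrays. A hedged sketch of that re-offsetting; the function and its inputs are hypothetical simplifications of what `concatenate_values` and the precomputed offsets do in the real growable:

    // Sketch: shift each array's dictionary keys by the values appended before it.
    fn shift_keys(keys_per_array: &[Vec<u32>], value_lens: &[usize]) -> Vec<u32> {
        let mut out = Vec::new();
        let mut offset = 0u32;
        for (keys, len) in keys_per_array.iter().zip(value_lens) {
            out.extend(keys.iter().map(|k| k + offset));
            offset += *len as u32;
        }
        out
    }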
@@ -58,7 +58,7 @@ impl<'a, T: DictionaryKey> GrowableDictionary<'a, T> { let (values, offsets) = concatenate_values(&arrays_keys, &arrays_values, capacity); Self { - data_type, + dtype, offsets, values, keys: arrays_keys, @@ -84,12 +84,8 @@ impl<'a, T: DictionaryKey> GrowableDictionary<'a, T> { // SAFETY: the invariant of this struct ensures that this is up-held unsafe { - DictionaryArray::::try_new_unchecked( - self.data_type.clone(), - keys, - self.values.clone(), - ) - .unwrap() + DictionaryArray::::try_new_unchecked(self.dtype.clone(), keys, self.values.clone()) + .unwrap() } } } diff --git a/crates/polars-arrow/src/array/growable/fixed_binary.rs b/crates/polars-arrow/src/array/growable/fixed_binary.rs index 0f52fcd51410..d3e0eae9562b 100644 --- a/crates/polars-arrow/src/array/growable/fixed_binary.rs +++ b/crates/polars-arrow/src/array/growable/fixed_binary.rs @@ -30,7 +30,7 @@ impl<'a> GrowableFixedSizeBinary<'a> { use_validity = true; }; - let size = FixedSizeBinaryArray::get_size(arrays[0].data_type()); + let size = FixedSizeBinaryArray::get_size(arrays[0].dtype()); Self { arrays, values: Vec::with_capacity(0), @@ -44,7 +44,7 @@ impl<'a> GrowableFixedSizeBinary<'a> { let values = std::mem::take(&mut self.values); FixedSizeBinaryArray::new( - self.arrays[0].data_type().clone(), + self.arrays[0].dtype().clone(), values.into(), validity.map(|v| v.into()), ) @@ -88,7 +88,7 @@ impl<'a> Growable<'a> for GrowableFixedSizeBinary<'a> { impl<'a> From> for FixedSizeBinaryArray { fn from(val: GrowableFixedSizeBinary<'a>) -> Self { FixedSizeBinaryArray::new( - val.arrays[0].data_type().clone(), + val.arrays[0].dtype().clone(), val.values.into(), val.validity.map(|v| v.into()), ) diff --git a/crates/polars-arrow/src/array/growable/fixed_size_list.rs b/crates/polars-arrow/src/array/growable/fixed_size_list.rs index 1841285f377d..c15202084006 100644 --- a/crates/polars-arrow/src/array/growable/fixed_size_list.rs +++ b/crates/polars-arrow/src/array/growable/fixed_size_list.rs @@ -33,13 +33,12 @@ impl<'a> GrowableFixedSizeList<'a> { use_validity = true; }; - let size = if let ArrowDataType::FixedSizeList(_, size) = - &arrays[0].data_type().to_logical_type() - { - *size - } else { - unreachable!("`GrowableFixedSizeList` expects `DataType::FixedSizeList`") - }; + let size = + if let ArrowDataType::FixedSizeList(_, size) = &arrays[0].dtype().to_logical_type() { + *size + } else { + unreachable!("`GrowableFixedSizeList` expects `DataType::FixedSizeList`") + }; let inner = arrays .iter() @@ -60,7 +59,7 @@ impl<'a> GrowableFixedSizeList<'a> { let values = self.values.as_box(); FixedSizeListArray::new( - self.arrays[0].data_type().clone(), + self.arrays[0].dtype().clone(), values, validity.map(|v| v.into()), ) @@ -111,7 +110,7 @@ impl<'a> From> for FixedSizeListArray { let values = values.as_box(); Self::new( - val.arrays[0].data_type().clone(), + val.arrays[0].dtype().clone(), values, val.validity.map(|v| v.into()), ) diff --git a/crates/polars-arrow/src/array/growable/list.rs b/crates/polars-arrow/src/array/growable/list.rs index a97518a310e3..90e4f15020a6 100644 --- a/crates/polars-arrow/src/array/growable/list.rs +++ b/crates/polars-arrow/src/array/growable/list.rs @@ -70,7 +70,7 @@ impl<'a, O: Offset> GrowableList<'a, O> { let values = self.values.as_box(); ListArray::::new( - self.arrays[0].data_type().clone(), + self.arrays[0].dtype().clone(), offsets.into(), values, validity.map(|v| v.into()), diff --git a/crates/polars-arrow/src/array/growable/mod.rs b/crates/polars-arrow/src/array/growable/mod.rs 
index ada21b71c121..1238b29f59a3 100644 --- a/crates/polars-arrow/src/array/growable/mod.rs +++ b/crates/polars-arrow/src/array/growable/mod.rs @@ -27,6 +27,7 @@ pub use dictionary::GrowableDictionary; mod binview; pub use binview::GrowableBinaryViewArray; + mod utils; /// Describes a struct that can be extended from slices of other pre-existing [`Array`]s. @@ -91,17 +92,15 @@ pub fn make_growable<'a>( capacity: usize, ) -> Box + 'a> { assert!(!arrays.is_empty()); - let data_type = arrays[0].data_type(); + let dtype = arrays[0].dtype(); use PhysicalType::*; - match data_type.to_physical_type() { - Null => Box::new(null::GrowableNull::new(data_type.clone())), + match dtype.to_physical_type() { + Null => Box::new(null::GrowableNull::new(dtype.clone())), Boolean => dyn_growable!(boolean::GrowableBoolean, arrays, use_validity, capacity), Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { dyn_growable!(primitive::GrowablePrimitive::<$T>, arrays, use_validity, capacity) }), - Utf8 => dyn_growable!(utf8::GrowableUtf8::, arrays, use_validity, capacity), - LargeUtf8 => dyn_growable!(utf8::GrowableUtf8::, arrays, use_validity, capacity), Binary => dyn_growable!( binary::GrowableBinary::, arrays, @@ -120,7 +119,6 @@ pub fn make_growable<'a>( use_validity, capacity ), - List => dyn_growable!(list::GrowableList::, arrays, use_validity, capacity), LargeList => dyn_growable!(list::GrowableList::, arrays, use_validity, capacity), Struct => dyn_growable!(structure::GrowableStruct, arrays, use_validity, capacity), FixedSizeList => dyn_growable!( @@ -163,6 +161,6 @@ pub fn make_growable<'a>( )) }) }, - Union | Map => unimplemented!(), + Union | Map | Utf8 | LargeUtf8 | List => unimplemented!(), } } diff --git a/crates/polars-arrow/src/array/growable/null.rs b/crates/polars-arrow/src/array/growable/null.rs index 155f90d190aa..c0b92e132819 100644 --- a/crates/polars-arrow/src/array/growable/null.rs +++ b/crates/polars-arrow/src/array/growable/null.rs @@ -6,7 +6,7 @@ use crate::datatypes::ArrowDataType; /// Concrete [`Growable`] for the [`NullArray`]. pub struct GrowableNull { - data_type: ArrowDataType, + dtype: ArrowDataType, length: usize, } @@ -18,11 +18,8 @@ impl Default for GrowableNull { impl GrowableNull { /// Creates a new [`GrowableNull`]. - pub fn new(data_type: ArrowDataType) -> Self { - Self { - data_type, - length: 0, - } + pub fn new(dtype: ArrowDataType) -> Self { + Self { dtype, length: 0 } } } @@ -41,16 +38,16 @@ impl<'a> Growable<'a> for GrowableNull { } fn as_arc(&mut self) -> Arc { - Arc::new(NullArray::new(self.data_type.clone(), self.length)) + Arc::new(NullArray::new(self.dtype.clone(), self.length)) } fn as_box(&mut self) -> Box { - Box::new(NullArray::new(self.data_type.clone(), self.length)) + Box::new(NullArray::new(self.dtype.clone(), self.length)) } } impl From for NullArray { fn from(val: GrowableNull) -> Self { - NullArray::new(val.data_type, val.length) + NullArray::new(val.dtype, val.length) } } diff --git a/crates/polars-arrow/src/array/growable/primitive.rs b/crates/polars-arrow/src/array/growable/primitive.rs index 936905ab05fa..715ce0f82ad7 100644 --- a/crates/polars-arrow/src/array/growable/primitive.rs +++ b/crates/polars-arrow/src/array/growable/primitive.rs @@ -11,7 +11,7 @@ use crate::types::NativeType; /// Concrete [`Growable`] for the [`PrimitiveArray`]. 
pub struct GrowablePrimitive<'a, T: NativeType> { - data_type: ArrowDataType, + dtype: ArrowDataType, arrays: Vec<&'a PrimitiveArray>, validity: Option, values: Vec, @@ -32,10 +32,10 @@ impl<'a, T: NativeType> GrowablePrimitive<'a, T> { use_validity = true; }; - let data_type = arrays[0].data_type().clone(); + let dtype = arrays[0].dtype().clone(); Self { - data_type, + dtype, arrays, values: Vec::with_capacity(capacity), validity: prepare_validity(use_validity, capacity), @@ -48,7 +48,7 @@ impl<'a, T: NativeType> GrowablePrimitive<'a, T> { let values = std::mem::take(&mut self.values); PrimitiveArray::::new( - self.data_type.clone(), + self.dtype.clone(), values.into(), validity.map(|v| v.into()), ) @@ -107,10 +107,6 @@ impl<'a, T: NativeType> Growable<'a> for GrowablePrimitive<'a, T> { impl<'a, T: NativeType> From> for PrimitiveArray { #[inline] fn from(val: GrowablePrimitive<'a, T>) -> Self { - PrimitiveArray::::new( - val.data_type, - val.values.into(), - val.validity.map(|v| v.into()), - ) + PrimitiveArray::::new(val.dtype, val.values.into(), val.validity.map(|v| v.into())) } } diff --git a/crates/polars-arrow/src/array/growable/structure.rs b/crates/polars-arrow/src/array/growable/structure.rs index a27a9cfe6bee..5f3d0c107c62 100644 --- a/crates/polars-arrow/src/array/growable/structure.rs +++ b/crates/polars-arrow/src/array/growable/structure.rs @@ -59,7 +59,7 @@ impl<'a> GrowableStruct<'a> { let values = values.into_iter().map(|mut x| x.as_box()).collect(); StructArray::new( - self.arrays[0].data_type().clone(), + self.arrays[0].dtype().clone(), values, validity.map(|v| v.into()), ) @@ -122,7 +122,7 @@ impl<'a> From> for StructArray { let values = val.values.into_iter().map(|mut x| x.as_box()).collect(); StructArray::new( - val.arrays[0].data_type().clone(), + val.arrays[0].dtype().clone(), values, val.validity.map(|v| v.into()), ) diff --git a/crates/polars-arrow/src/array/growable/utf8.rs b/crates/polars-arrow/src/array/growable/utf8.rs index f4e4e762fc67..4fc1c415d74e 100644 --- a/crates/polars-arrow/src/array/growable/utf8.rs +++ b/crates/polars-arrow/src/array/growable/utf8.rs @@ -48,7 +48,7 @@ impl<'a, O: Offset> GrowableUtf8<'a, O> { unsafe { Utf8Array::::new_unchecked( - self.arrays[0].data_type().clone(), + self.arrays[0].dtype().clone(), offsets.into(), values.into(), validity.map(|v| v.into()), diff --git a/crates/polars-arrow/src/array/list/data.rs b/crates/polars-arrow/src/array/list/data.rs index 212778a05abb..0d28583df125 100644 --- a/crates/polars-arrow/src/array/list/data.rs +++ b/crates/polars-arrow/src/array/list/data.rs @@ -6,9 +6,9 @@ use crate::offset::{Offset, OffsetsBuffer}; impl Arrow2Arrow for ListArray { fn to_data(&self) -> ArrayData { - let data_type = self.data_type.clone().into(); + let dtype = self.dtype.clone().into(); - let builder = ArrayDataBuilder::new(data_type) + let builder = ArrayDataBuilder::new(dtype) .len(self.len()) .buffers(vec![self.offsets.clone().into_inner().into()]) .nulls(self.validity.as_ref().map(|b| b.clone().into())) @@ -19,17 +19,17 @@ impl Arrow2Arrow for ListArray { } fn from_data(data: &ArrayData) -> Self { - let data_type = data.data_type().clone().into(); + let dtype = data.data_type().clone().into(); if data.is_empty() { // Handle empty offsets - return Self::new_empty(data_type); + return Self::new_empty(dtype); } let mut offsets = unsafe { OffsetsBuffer::new_unchecked(data.buffers()[0].clone().into()) }; offsets.slice(data.offset(), data.len() + 1); Self { - data_type, + dtype, offsets, values: 
from_data(&data.child_data()[0]), validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), diff --git a/crates/polars-arrow/src/array/list/ffi.rs b/crates/polars-arrow/src/array/list/ffi.rs index e536a713cbc2..2ac23e45635b 100644 --- a/crates/polars-arrow/src/array/list/ffi.rs +++ b/crates/polars-arrow/src/array/list/ffi.rs @@ -45,7 +45,7 @@ unsafe impl ToFfi for ListArray { }); Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), validity, offsets: self.offsets.clone(), values: self.values.clone(), @@ -55,7 +55,7 @@ unsafe impl ToFfi for ListArray { impl FromFfi for ListArray { unsafe fn try_from_ffi(array: A) -> PolarsResult { - let data_type = array.data_type().clone(); + let dtype = array.dtype().clone(); let validity = unsafe { array.validity() }?; let offsets = unsafe { array.buffer::(1) }?; let child = unsafe { array.child(0)? }; @@ -64,6 +64,6 @@ impl FromFfi for ListArray { // assumption that data from FFI is well constructed let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets) }; - Self::try_new(data_type, offsets, values, validity) + Self::try_new(dtype, offsets, values, validity) } } diff --git a/crates/polars-arrow/src/array/list/mod.rs b/crates/polars-arrow/src/array/list/mod.rs index 27c20b72d0ea..3c2bb6b41f98 100644 --- a/crates/polars-arrow/src/array/list/mod.rs +++ b/crates/polars-arrow/src/array/list/mod.rs @@ -13,11 +13,12 @@ pub use iterator::*; mod mutable; pub use mutable::*; use polars_error::{polars_bail, PolarsResult}; +use polars_utils::pl_str::PlSmallStr; /// An [`Array`] semantically equivalent to `Vec>>>` with Arrow's in-memory. #[derive(Clone)] pub struct ListArray { - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: OffsetsBuffer, values: Box, validity: Option, @@ -30,12 +31,12 @@ impl ListArray { /// This function returns an error iff: /// * The last offset is not equal to the values' length. /// * the validity's length is not equal to `offsets.len()`. - /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either [`crate::datatypes::PhysicalType::List`] or [`crate::datatypes::PhysicalType::LargeList`]. - /// * The `data_type`'s inner field's data type is not equal to `values.data_type`. + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to either [`crate::datatypes::PhysicalType::List`] or [`crate::datatypes::PhysicalType::LargeList`]. + /// * The `dtype`'s inner field's data type is not equal to `values.dtype`. /// # Implementation /// This function is `O(1)` pub fn try_new( - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: OffsetsBuffer, values: Box, validity: Option, @@ -49,14 +50,14 @@ impl ListArray { polars_bail!(ComputeError: "validity mask length must match the number of values") } - let child_data_type = Self::try_get_child(&data_type)?.data_type(); - let values_data_type = values.data_type(); - if child_data_type != values_data_type { - polars_bail!(ComputeError: "ListArray's child's DataType must match. However, the expected DataType is {child_data_type:?} while it got {values_data_type:?}."); + let child_dtype = Self::try_get_child(&dtype)?.dtype(); + let values_dtype = values.dtype(); + if child_dtype != values_dtype { + polars_bail!(ComputeError: "ListArray's child's DataType must match. 
However, the expected DataType is {child_dtype:?} while it got {values_dtype:?}."); } Ok(Self { - data_type, + dtype, offsets, values, validity, @@ -69,31 +70,31 @@ impl ListArray { /// This function panics iff: /// * The last offset is not equal to the values' length. /// * the validity's length is not equal to `offsets.len()`. - /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either [`crate::datatypes::PhysicalType::List`] or [`crate::datatypes::PhysicalType::LargeList`]. - /// * The `data_type`'s inner field's data type is not equal to `values.data_type`. + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to either [`crate::datatypes::PhysicalType::List`] or [`crate::datatypes::PhysicalType::LargeList`]. + /// * The `dtype`'s inner field's data type is not equal to `values.dtype`. /// # Implementation /// This function is `O(1)` pub fn new( - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: OffsetsBuffer, values: Box, validity: Option, ) -> Self { - Self::try_new(data_type, offsets, values, validity).unwrap() + Self::try_new(dtype, offsets, values, validity).unwrap() } /// Returns a new empty [`ListArray`]. - pub fn new_empty(data_type: ArrowDataType) -> Self { - let values = new_empty_array(Self::get_child_type(&data_type).clone()); - Self::new(data_type, OffsetsBuffer::default(), values, None) + pub fn new_empty(dtype: ArrowDataType) -> Self { + let values = new_empty_array(Self::get_child_type(&dtype).clone()); + Self::new(dtype, OffsetsBuffer::default(), values, None) } /// Returns a new null [`ListArray`]. #[inline] - pub fn new_null(data_type: ArrowDataType, length: usize) -> Self { - let child = Self::get_child_type(&data_type).clone(); + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { + let child = Self::get_child_type(&dtype).clone(); Self::new( - data_type, + dtype, Offsets::new_zeroed(length).into(), new_empty_array(child), Some(Bitmap::new_zeroed(length)), @@ -184,8 +185,8 @@ impl ListArray { impl ListArray { /// Returns a default [`ArrowDataType`]: inner field is named "item" and is nullable - pub fn default_datatype(data_type: ArrowDataType) -> ArrowDataType { - let field = Box::new(Field::new("item", data_type, true)); + pub fn default_datatype(dtype: ArrowDataType) -> ArrowDataType { + let field = Box::new(Field::new(PlSmallStr::from_static("item"), dtype, true)); if O::IS_LARGE { ArrowDataType::LargeList(field) } else { @@ -196,21 +197,21 @@ impl ListArray { /// Returns a the inner [`Field`] /// # Panics /// Panics iff the logical type is not consistent with this struct. - pub fn get_child_field(data_type: &ArrowDataType) -> &Field { - Self::try_get_child(data_type).unwrap() + pub fn get_child_field(dtype: &ArrowDataType) -> &Field { + Self::try_get_child(dtype).unwrap() } /// Returns a the inner [`Field`] /// # Errors /// Panics iff the logical type is not consistent with this struct. 
- pub fn try_get_child(data_type: &ArrowDataType) -> PolarsResult<&Field> { + pub fn try_get_child(dtype: &ArrowDataType) -> PolarsResult<&Field> { if O::IS_LARGE { - match data_type.to_logical_type() { + match dtype.to_logical_type() { ArrowDataType::LargeList(child) => Ok(child.as_ref()), _ => polars_bail!(ComputeError: "ListArray expects DataType::LargeList"), } } else { - match data_type.to_logical_type() { + match dtype.to_logical_type() { ArrowDataType::List(child) => Ok(child.as_ref()), _ => polars_bail!(ComputeError: "ListArray expects DataType::List"), } @@ -220,8 +221,8 @@ impl ListArray { /// Returns a the inner [`ArrowDataType`] /// # Panics /// Panics iff the logical type is not consistent with this struct. - pub fn get_child_type(data_type: &ArrowDataType) -> &ArrowDataType { - Self::get_child_field(data_type).data_type() + pub fn get_child_type(dtype: &ArrowDataType) -> &ArrowDataType { + Self::get_child_field(dtype).dtype() } } @@ -249,13 +250,13 @@ impl Splitable for ListArray { ( Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), offsets: lhs_offsets, validity: lhs_validity, values: self.values.clone(), }, Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), offsets: rhs_offsets, validity: rhs_validity, values: self.values.clone(), diff --git a/crates/polars-arrow/src/array/list/mutable.rs b/crates/polars-arrow/src/array/list/mutable.rs index 3fd528019063..a52e095d72bb 100644 --- a/crates/polars-arrow/src/array/list/mutable.rs +++ b/crates/polars-arrow/src/array/list/mutable.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use polars_error::{polars_err, PolarsResult}; +use polars_utils::pl_str::PlSmallStr; use super::ListArray; use crate::array::physical_binary::extend_validity; @@ -13,7 +14,7 @@ use crate::trusted_len::TrustedLen; /// The mutable version of [`ListArray`]. #[derive(Debug, Clone)] pub struct MutableListArray { - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: Offsets, values: M, validity: Option, @@ -23,18 +24,18 @@ impl MutableListArray { /// Creates a new empty [`MutableListArray`]. pub fn new() -> Self { let values = M::default(); - let data_type = ListArray::::default_datatype(values.data_type().clone()); - Self::new_from(values, data_type, 0) + let dtype = ListArray::::default_datatype(values.dtype().clone()); + Self::new_from(values, dtype, 0) } /// Creates a new [`MutableListArray`] with a capacity. pub fn with_capacity(capacity: usize) -> Self { let values = M::default(); - let data_type = ListArray::::default_datatype(values.data_type().clone()); + let dtype = ListArray::::default_datatype(values.dtype().clone()); let offsets = Offsets::::with_capacity(capacity); Self { - data_type, + dtype, offsets, values, validity: None, @@ -51,7 +52,7 @@ impl Default for MutableListArray { impl From> for ListArray { fn from(mut other: MutableListArray) -> Self { ListArray::new( - other.data_type, + other.dtype, other.offsets.into(), other.values.as_box(), other.validity.map(|x| x.into()), @@ -109,12 +110,12 @@ where impl MutableListArray { /// Creates a new [`MutableListArray`] from a [`MutableArray`] and capacity. 
- pub fn new_from(values: M, data_type: ArrowDataType, capacity: usize) -> Self { + pub fn new_from(values: M, dtype: ArrowDataType, capacity: usize) -> Self { let offsets = Offsets::::with_capacity(capacity); assert_eq!(values.len(), 0); - ListArray::::get_child_field(&data_type); + ListArray::::get_child_field(&dtype); Self { - data_type, + dtype, offsets, values, validity: None, @@ -122,20 +123,20 @@ impl MutableListArray { } /// Creates a new [`MutableListArray`] from a [`MutableArray`]. - pub fn new_with_field(values: M, name: &str, nullable: bool) -> Self { - let field = Box::new(Field::new(name, values.data_type().clone(), nullable)); - let data_type = if O::IS_LARGE { + pub fn new_with_field(values: M, name: PlSmallStr, nullable: bool) -> Self { + let field = Box::new(Field::new(name, values.dtype().clone(), nullable)); + let dtype = if O::IS_LARGE { ArrowDataType::LargeList(field) } else { ArrowDataType::List(field) }; - Self::new_from(values, data_type, 0) + Self::new_from(values, dtype, 0) } /// Creates a new [`MutableListArray`] from a [`MutableArray`] and capacity. pub fn new_with_capacity(values: M, capacity: usize) -> Self { - let data_type = ListArray::::default_datatype(values.data_type().clone()); - Self::new_from(values, data_type, capacity) + let dtype = ListArray::::default_datatype(values.dtype().clone()); + Self::new_from(values, dtype, capacity) } /// Creates a new [`MutableListArray`] from a [`MutableArray`], [`Offsets`] and @@ -146,9 +147,9 @@ impl MutableListArray { validity: Option, ) -> Self { assert_eq!(values.len(), offsets.last().to_usize()); - let data_type = ListArray::::default_datatype(values.data_type().clone()); + let dtype = ListArray::::default_datatype(values.dtype().clone()); Self { - data_type, + dtype, offsets, values, validity, @@ -273,7 +274,7 @@ impl MutableArray for MutableListArray Box { ListArray::new( - self.data_type.clone(), + self.dtype.clone(), std::mem::take(&mut self.offsets).into(), self.values.as_box(), std::mem::take(&mut self.validity).map(|x| x.into()), @@ -283,7 +284,7 @@ impl MutableArray for MutableListArray Arc { ListArray::new( - self.data_type.clone(), + self.dtype.clone(), std::mem::take(&mut self.offsets).into(), self.values.as_box(), std::mem::take(&mut self.validity).map(|x| x.into()), @@ -291,8 +292,8 @@ impl MutableArray for MutableListArray &ArrowDataType { - &self.data_type + fn dtype(&self) -> &ArrowDataType { + &self.dtype } fn as_any(&self) -> &dyn std::any::Any { diff --git a/crates/polars-arrow/src/array/map/data.rs b/crates/polars-arrow/src/array/map/data.rs index 8eb586e05f4c..b5530886d817 100644 --- a/crates/polars-arrow/src/array/map/data.rs +++ b/crates/polars-arrow/src/array/map/data.rs @@ -6,9 +6,9 @@ use crate::offset::OffsetsBuffer; impl Arrow2Arrow for MapArray { fn to_data(&self) -> ArrayData { - let data_type = self.data_type.clone().into(); + let dtype = self.dtype.clone().into(); - let builder = ArrayDataBuilder::new(data_type) + let builder = ArrayDataBuilder::new(dtype) .len(self.len()) .buffers(vec![self.offsets.clone().into_inner().into()]) .nulls(self.validity.as_ref().map(|b| b.clone().into())) @@ -19,17 +19,17 @@ impl Arrow2Arrow for MapArray { } fn from_data(data: &ArrayData) -> Self { - let data_type = data.data_type().clone().into(); + let dtype = data.data_type().clone().into(); if data.is_empty() { // Handle empty offsets - return Self::new_empty(data_type); + return Self::new_empty(dtype); } let mut offsets = unsafe { OffsetsBuffer::new_unchecked(data.buffers()[0].clone().into()) }; 
offsets.slice(data.offset(), data.len() + 1); Self { - data_type: data.data_type().clone().into(), + dtype: data.data_type().clone().into(), offsets, field: from_data(&data.child_data()[0]), validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), diff --git a/crates/polars-arrow/src/array/map/ffi.rs b/crates/polars-arrow/src/array/map/ffi.rs index fad531671703..2233b371f7eb 100644 --- a/crates/polars-arrow/src/array/map/ffi.rs +++ b/crates/polars-arrow/src/array/map/ffi.rs @@ -45,7 +45,7 @@ unsafe impl ToFfi for MapArray { }); Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), validity, offsets: self.offsets.clone(), field: self.field.clone(), @@ -55,7 +55,7 @@ unsafe impl ToFfi for MapArray { impl FromFfi for MapArray { unsafe fn try_from_ffi(array: A) -> PolarsResult { - let data_type = array.data_type().clone(); + let dtype = array.dtype().clone(); let validity = unsafe { array.validity() }?; let offsets = unsafe { array.buffer::(1) }?; let child = array.child(0)?; @@ -64,6 +64,6 @@ impl FromFfi for MapArray { // assumption that data from FFI is well constructed let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets) }; - Self::try_new(data_type, offsets, values, validity) + Self::try_new(dtype, offsets, values, validity) } } diff --git a/crates/polars-arrow/src/array/map/mod.rs b/crates/polars-arrow/src/array/map/mod.rs index 219d703329e3..5497c1d7342b 100644 --- a/crates/polars-arrow/src/array/map/mod.rs +++ b/crates/polars-arrow/src/array/map/mod.rs @@ -15,7 +15,7 @@ use polars_error::{polars_bail, PolarsResult}; /// An array representing a (key, value), both of arbitrary logical types. #[derive(Clone)] pub struct MapArray { - data_type: ArrowDataType, + dtype: ArrowDataType, // invariant: field.len() == offsets.len() offsets: OffsetsBuffer, field: Box, @@ -28,27 +28,27 @@ impl MapArray { /// # Errors /// This function errors iff: /// * The last offset is not equal to the field' length - /// * The `data_type`'s physical type is not [`crate::datatypes::PhysicalType::Map`] - /// * The fields' `data_type` is not equal to the inner field of `data_type` + /// * The `dtype`'s physical type is not [`crate::datatypes::PhysicalType::Map`] + /// * The fields' `dtype` is not equal to the inner field of `dtype` /// * The validity is not `None` and its length is different from `offsets.len() - 1`. pub fn try_new( - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: OffsetsBuffer, field: Box, validity: Option, ) -> PolarsResult { try_check_offsets_bounds(&offsets, field.len())?; - let inner_field = Self::try_get_field(&data_type)?; - if let ArrowDataType::Struct(inner) = inner_field.data_type() { + let inner_field = Self::try_get_field(&dtype)?; + if let ArrowDataType::Struct(inner) = inner_field.dtype() { if inner.len() != 2 { polars_bail!(ComputeError: "MapArray's inner `Struct` must have 2 fields (keys and maps)") } } else { polars_bail!(ComputeError: "MapArray expects `DataType::Struct` as its inner logical type") } - if field.data_type() != inner_field.data_type() { - polars_bail!(ComputeError: "MapArray expects `field.data_type` to match its inner DataType") + if field.dtype() != inner_field.dtype() { + polars_bail!(ComputeError: "MapArray expects `field.dtype` to match its inner DataType") } if validity @@ -59,7 +59,7 @@ impl MapArray { } Ok(Self { - data_type, + dtype, field, offsets, validity, @@ -69,22 +69,22 @@ impl MapArray { /// Creates a new [`MapArray`]. /// # Panics /// * The last offset is not equal to the field' length. 
- /// * The `data_type`'s physical type is not [`crate::datatypes::PhysicalType::Map`], + /// * The `dtype`'s physical type is not [`crate::datatypes::PhysicalType::Map`], /// * The validity is not `None` and its length is different from `offsets.len() - 1`. pub fn new( - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: OffsetsBuffer, field: Box, validity: Option, ) -> Self { - Self::try_new(data_type, offsets, field, validity).unwrap() + Self::try_new(dtype, offsets, field, validity).unwrap() } /// Returns a new null [`MapArray`] of `length`. - pub fn new_null(data_type: ArrowDataType, length: usize) -> Self { - let field = new_empty_array(Self::get_field(&data_type).data_type().clone()); + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { + let field = new_empty_array(Self::get_field(&dtype).dtype().clone()); Self::new( - data_type, + dtype, vec![0i32; 1 + length].try_into().unwrap(), field, Some(Bitmap::new_zeroed(length)), @@ -92,9 +92,9 @@ impl MapArray { } /// Returns a new empty [`MapArray`]. - pub fn new_empty(data_type: ArrowDataType) -> Self { - let field = new_empty_array(Self::get_field(&data_type).data_type().clone()); - Self::new(data_type, OffsetsBuffer::default(), field, None) + pub fn new_empty(dtype: ArrowDataType) -> Self { + let field = new_empty_array(Self::get_field(&dtype).dtype().clone()); + Self::new(dtype, OffsetsBuffer::default(), field, None) } } @@ -128,16 +128,16 @@ impl MapArray { impl_mut_validity!(); impl_into_array!(); - pub(crate) fn try_get_field(data_type: &ArrowDataType) -> PolarsResult<&Field> { - if let ArrowDataType::Map(field, _) = data_type.to_logical_type() { + pub(crate) fn try_get_field(dtype: &ArrowDataType) -> PolarsResult<&Field> { + if let ArrowDataType::Map(field, _) = dtype.to_logical_type() { Ok(field.as_ref()) } else { - polars_bail!(ComputeError: "The data_type's logical type must be DataType::Map") + polars_bail!(ComputeError: "The dtype's logical type must be DataType::Map") } } - pub(crate) fn get_field(data_type: &ArrowDataType) -> &Field { - Self::try_get_field(data_type).unwrap() + pub(crate) fn get_field(dtype: &ArrowDataType) -> &Field { + Self::try_get_field(dtype).unwrap() } } @@ -207,13 +207,13 @@ impl Splitable for MapArray { ( Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), offsets: lhs_offsets, field: self.field.clone(), validity: lhs_validity, }, Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), offsets: rhs_offsets, field: self.field.clone(), validity: rhs_validity, diff --git a/crates/polars-arrow/src/array/mod.rs b/crates/polars-arrow/src/array/mod.rs index c2c0c958032d..34b891ccbd1a 100644 --- a/crates/polars-arrow/src/array/mod.rs +++ b/crates/polars-arrow/src/array/mod.rs @@ -12,7 +12,7 @@ //! * [`StructArray`] and [`MutableStructArray`], an array of arrays identified by a string (e.g. `{"a": [1, 2], "b": [true, false]}`) //! //! All immutable arrays implement the trait object [`Array`] and that can be downcasted -//! to a concrete struct based on [`PhysicalType`](crate::datatypes::PhysicalType) available from [`Array::data_type`]. +//! to a concrete struct based on [`PhysicalType`](crate::datatypes::PhysicalType) available from [`Array::dtype`]. //! All immutable arrays are backed by [`Buffer`](crate::buffer::Buffer) and thus cloning and slicing them is `O(1)`. //! //! 
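// Illustrative sketch, not part of this diff: building an empty MapArray with the renamed
// `dtype` parameter. The inner struct must have exactly two fields (keys and values); the
// field names used here ("entries", "key", "value") are arbitrary choices for the example.
use polars_arrow::array::MapArray;
use polars_arrow::datatypes::{ArrowDataType, Field};

fn empty_map_sketch() -> MapArray {
    let entries = ArrowDataType::Struct(vec![
        Field::new("key".into(), ArrowDataType::LargeUtf8, false),
        Field::new("value".into(), ArrowDataType::Int64, true),
    ]);
    let field = Field::new("entries".into(), entries, false);
    MapArray::new_empty(ArrowDataType::Map(Box::new(field), false))
}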
Most arrays contain a [`MutableArray`] counterpart that is neither clonable nor sliceable, but @@ -58,7 +58,7 @@ pub trait Splitable: Sized { } /// A trait representing an immutable Arrow array. Arrow arrays are trait objects -/// that are infallibly downcasted to concrete types according to the [`Array::data_type`]. +/// that are infallibly downcasted to concrete types according to the [`Array::dtype`]. pub trait Array: Send + Sync + dyn_clone::DynClone + 'static { /// Converts itself to a reference of [`Any`], which enables downcasting to concrete types. fn as_any(&self) -> &dyn Any; @@ -77,7 +77,7 @@ pub trait Array: Send + Sync + dyn_clone::DynClone + 'static { /// The [`ArrowDataType`] of the [`Array`]. In combination with [`Array::as_any`], this can be /// used to downcast trait objects (`dyn Array`) to concrete arrays. - fn data_type(&self) -> &ArrowDataType; + fn dtype(&self) -> &ArrowDataType; /// The validity of the [`Array`]: every array has an optional [`Bitmap`] that, when available /// specifies whether the array slot is valid or not (null). @@ -89,7 +89,7 @@ pub trait Array: Send + Sync + dyn_clone::DynClone + 'static { /// This is `O(1)` since the number of null elements is pre-computed. #[inline] fn null_count(&self) -> usize { - if self.data_type() == &ArrowDataType::Null { + if self.dtype() == &ArrowDataType::Null { return self.len(); }; self.validity() @@ -162,7 +162,7 @@ pub trait Array: Send + Sync + dyn_clone::DynClone + 'static { #[must_use] fn sliced(&self, offset: usize, length: usize) -> Box { if length == 0 { - return new_empty_array(self.data_type().clone()); + return new_empty_array(self.dtype().clone()); } let mut new = self.to_boxed(); new.slice(offset, length); @@ -195,12 +195,13 @@ pub trait Array: Send + Sync + dyn_clone::DynClone + 'static { dyn_clone::clone_trait_object!(Array); /// A trait describing a mutable array; i.e. an array whose values can be changed. +/// /// Mutable arrays cannot be cloned but can be mutated in place, /// thereby making them useful to perform numeric operations without allocations. /// As in [`Array`], concrete arrays (such as [`MutablePrimitiveArray`]) implement how they are mutated. pub trait MutableArray: std::fmt::Debug + Send + Sync { /// The [`ArrowDataType`] of the array. - fn data_type(&self) -> &ArrowDataType; + fn dtype(&self) -> &ArrowDataType; /// The length of the array. fn len(&self) -> usize; @@ -268,8 +269,8 @@ impl MutableArray for Box { self.as_mut().as_arc() } - fn data_type(&self) -> &ArrowDataType { - self.as_ref().data_type() + fn dtype(&self) -> &ArrowDataType { + self.as_ref().dtype() } fn as_any(&self) -> &dyn std::any::Any { @@ -311,7 +312,7 @@ macro_rules! fmt_dyn { impl std::fmt::Debug for dyn Array + '_ { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { use crate::datatypes::PhysicalType::*; - match self.data_type().to_physical_type() { + match self.dtype().to_physical_type() { Null => fmt_dyn!(self, NullArray, f), Boolean => fmt_dyn!(self, BooleanArray, f), Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { @@ -340,62 +341,63 @@ impl std::fmt::Debug for dyn Array + '_ { } /// Creates a new [`Array`] with a [`Array::len`] of 0. 
-pub fn new_empty_array(data_type: ArrowDataType) -> Box { +pub fn new_empty_array(dtype: ArrowDataType) -> Box { use crate::datatypes::PhysicalType::*; - match data_type.to_physical_type() { - Null => Box::new(NullArray::new_empty(data_type)), - Boolean => Box::new(BooleanArray::new_empty(data_type)), + match dtype.to_physical_type() { + Null => Box::new(NullArray::new_empty(dtype)), + Boolean => Box::new(BooleanArray::new_empty(dtype)), Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { - Box::new(PrimitiveArray::<$T>::new_empty(data_type)) + Box::new(PrimitiveArray::<$T>::new_empty(dtype)) }), - Binary => Box::new(BinaryArray::::new_empty(data_type)), - LargeBinary => Box::new(BinaryArray::::new_empty(data_type)), - FixedSizeBinary => Box::new(FixedSizeBinaryArray::new_empty(data_type)), - Utf8 => Box::new(Utf8Array::::new_empty(data_type)), - LargeUtf8 => Box::new(Utf8Array::::new_empty(data_type)), - List => Box::new(ListArray::::new_empty(data_type)), - LargeList => Box::new(ListArray::::new_empty(data_type)), - FixedSizeList => Box::new(FixedSizeListArray::new_empty(data_type)), - Struct => Box::new(StructArray::new_empty(data_type)), - Union => Box::new(UnionArray::new_empty(data_type)), - Map => Box::new(MapArray::new_empty(data_type)), - Utf8View => Box::new(Utf8ViewArray::new_empty(data_type)), - BinaryView => Box::new(BinaryViewArray::new_empty(data_type)), + Binary => Box::new(BinaryArray::::new_empty(dtype)), + LargeBinary => Box::new(BinaryArray::::new_empty(dtype)), + FixedSizeBinary => Box::new(FixedSizeBinaryArray::new_empty(dtype)), + Utf8 => Box::new(Utf8Array::::new_empty(dtype)), + LargeUtf8 => Box::new(Utf8Array::::new_empty(dtype)), + List => Box::new(ListArray::::new_empty(dtype)), + LargeList => Box::new(ListArray::::new_empty(dtype)), + FixedSizeList => Box::new(FixedSizeListArray::new_empty(dtype)), + Struct => Box::new(StructArray::new_empty(dtype)), + Union => Box::new(UnionArray::new_empty(dtype)), + Map => Box::new(MapArray::new_empty(dtype)), + Utf8View => Box::new(Utf8ViewArray::new_empty(dtype)), + BinaryView => Box::new(BinaryViewArray::new_empty(dtype)), Dictionary(key_type) => { match_integer_type!(key_type, |$T| { - Box::new(DictionaryArray::<$T>::new_empty(data_type)) + Box::new(DictionaryArray::<$T>::new_empty(dtype)) }) }, } } -/// Creates a new [`Array`] of [`ArrowDataType`] `data_type` and `length`. +/// Creates a new [`Array`] of [`ArrowDataType`] `dtype` and `length`. +/// /// The array is guaranteed to have [`Array::null_count`] equal to [`Array::len`] /// for all types except Union, which does not have a validity. 
-pub fn new_null_array(data_type: ArrowDataType, length: usize) -> Box { +pub fn new_null_array(dtype: ArrowDataType, length: usize) -> Box { use crate::datatypes::PhysicalType::*; - match data_type.to_physical_type() { - Null => Box::new(NullArray::new_null(data_type, length)), - Boolean => Box::new(BooleanArray::new_null(data_type, length)), + match dtype.to_physical_type() { + Null => Box::new(NullArray::new_null(dtype, length)), + Boolean => Box::new(BooleanArray::new_null(dtype, length)), Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { - Box::new(PrimitiveArray::<$T>::new_null(data_type, length)) + Box::new(PrimitiveArray::<$T>::new_null(dtype, length)) }), - Binary => Box::new(BinaryArray::::new_null(data_type, length)), - LargeBinary => Box::new(BinaryArray::::new_null(data_type, length)), - FixedSizeBinary => Box::new(FixedSizeBinaryArray::new_null(data_type, length)), - Utf8 => Box::new(Utf8Array::::new_null(data_type, length)), - LargeUtf8 => Box::new(Utf8Array::::new_null(data_type, length)), - List => Box::new(ListArray::::new_null(data_type, length)), - LargeList => Box::new(ListArray::::new_null(data_type, length)), - FixedSizeList => Box::new(FixedSizeListArray::new_null(data_type, length)), - Struct => Box::new(StructArray::new_null(data_type, length)), - Union => Box::new(UnionArray::new_null(data_type, length)), - Map => Box::new(MapArray::new_null(data_type, length)), - BinaryView => Box::new(BinaryViewArray::new_null(data_type, length)), - Utf8View => Box::new(Utf8ViewArray::new_null(data_type, length)), + Binary => Box::new(BinaryArray::::new_null(dtype, length)), + LargeBinary => Box::new(BinaryArray::::new_null(dtype, length)), + FixedSizeBinary => Box::new(FixedSizeBinaryArray::new_null(dtype, length)), + Utf8 => Box::new(Utf8Array::::new_null(dtype, length)), + LargeUtf8 => Box::new(Utf8Array::::new_null(dtype, length)), + List => Box::new(ListArray::::new_null(dtype, length)), + LargeList => Box::new(ListArray::::new_null(dtype, length)), + FixedSizeList => Box::new(FixedSizeListArray::new_null(dtype, length)), + Struct => Box::new(StructArray::new_null(dtype, length)), + Union => Box::new(UnionArray::new_null(dtype, length)), + Map => Box::new(MapArray::new_null(dtype, length)), + BinaryView => Box::new(BinaryViewArray::new_null(dtype, length)), + Utf8View => Box::new(Utf8ViewArray::new_null(dtype, length)), Dictionary(key_type) => { match_integer_type!(key_type, |$T| { - Box::new(DictionaryArray::<$T>::new_null(data_type, length)) + Box::new(DictionaryArray::<$T>::new_null(dtype, length)) }) }, } @@ -453,7 +455,7 @@ impl From<&dyn arrow_array::Array> for Box { #[cfg(feature = "arrow_rs")] pub fn to_data(array: &dyn Array) -> arrow_data::ArrayData { use crate::datatypes::PhysicalType::*; - match array.data_type().to_physical_type() { + match array.dtype().to_physical_type() { Null => to_data_dyn!(array, NullArray), Boolean => to_data_dyn!(array, BooleanArray), Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { @@ -483,8 +485,8 @@ pub fn to_data(array: &dyn Array) -> arrow_data::ArrayData { #[cfg(feature = "arrow_rs")] pub fn from_data(data: &arrow_data::ArrayData) -> Box { use crate::datatypes::PhysicalType::*; - let data_type: ArrowDataType = data.data_type().clone().into(); - match data_type.to_physical_type() { + let dtype: ArrowDataType = data.data_type().clone().into(); + match dtype.to_physical_type() { Null => Box::new(NullArray::from_data(data)), Boolean => Box::new(BooleanArray::from_data(data)), 
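// Illustrative sketch, not part of this diff: the free constructors keep their behaviour
// under the rename; a null array reports `null_count() == len()` for every type except Union.
use polars_arrow::array::{new_empty_array, new_null_array, Array};
use polars_arrow::datatypes::ArrowDataType;

fn null_and_empty_sketch() {
    let nulls = new_null_array(ArrowDataType::Int32, 4);
    assert_eq!(nulls.null_count(), 4);
    assert_eq!(new_empty_array(ArrowDataType::Utf8View).len(), 0);
}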
Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { @@ -655,8 +657,8 @@ macro_rules! impl_common_array { } #[inline] - fn data_type(&self) -> &ArrowDataType { - &self.data_type + fn dtype(&self) -> &ArrowDataType { + &self.dtype } #[inline] @@ -697,7 +699,7 @@ macro_rules! impl_common_array { /// and moving the concrete struct under a `Box`. pub fn clone(array: &dyn Array) -> Box { use crate::datatypes::PhysicalType::*; - match array.data_type().to_physical_type() { + match array.dtype().to_physical_type() { Null => clone_dyn!(array, NullArray), Boolean => clone_dyn!(array, BooleanArray), Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { @@ -762,8 +764,8 @@ mod values; pub use binary::{BinaryArray, BinaryValueIter, MutableBinaryArray, MutableBinaryValuesArray}; pub use binview::{ - BinaryViewArray, BinaryViewArrayGeneric, MutableBinaryViewArray, MutablePlBinary, - MutablePlString, Utf8ViewArray, View, ViewType, + validate_utf8_view, BinaryViewArray, BinaryViewArrayGeneric, MutableBinaryViewArray, + MutablePlBinary, MutablePlString, Utf8ViewArray, View, ViewType, }; pub use boolean::{BooleanArray, MutableBooleanArray}; pub use dictionary::{DictionaryArray, DictionaryKey, MutableDictionaryArray}; diff --git a/crates/polars-arrow/src/array/null.rs b/crates/polars-arrow/src/array/null.rs index 22c3d6077686..4960b263667c 100644 --- a/crates/polars-arrow/src/array/null.rs +++ b/crates/polars-arrow/src/array/null.rs @@ -11,7 +11,7 @@ use crate::ffi; /// The concrete [`Array`] of [`ArrowDataType::Null`]. #[derive(Clone)] pub struct NullArray { - data_type: ArrowDataType, + dtype: ArrowDataType, /// Validity mask. This is always all-zeroes. validity: Bitmap, @@ -23,16 +23,16 @@ impl NullArray { /// Returns a new [`NullArray`]. /// # Errors /// This function errors iff: - /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to [`crate::datatypes::PhysicalType::Null`]. - pub fn try_new(data_type: ArrowDataType, length: usize) -> PolarsResult { - if data_type.to_physical_type() != PhysicalType::Null { + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to [`crate::datatypes::PhysicalType::Null`]. + pub fn try_new(dtype: ArrowDataType, length: usize) -> PolarsResult { + if dtype.to_physical_type() != PhysicalType::Null { polars_bail!(ComputeError: "NullArray can only be initialized with a DataType whose physical type is Null"); } let validity = Bitmap::new_zeroed(length); Ok(Self { - data_type, + dtype, validity, length, }) @@ -41,19 +41,19 @@ impl NullArray { /// Returns a new [`NullArray`]. /// # Panics /// This function errors iff: - /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to [`crate::datatypes::PhysicalType::Null`]. - pub fn new(data_type: ArrowDataType, length: usize) -> Self { - Self::try_new(data_type, length).unwrap() + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to [`crate::datatypes::PhysicalType::Null`]. + pub fn new(dtype: ArrowDataType, length: usize) -> Self { + Self::try_new(dtype, length).unwrap() } /// Returns a new empty [`NullArray`]. - pub fn new_empty(data_type: ArrowDataType) -> Self { - Self::new(data_type, 0) + pub fn new_empty(dtype: ArrowDataType) -> Self { + Self::new(dtype, 0) } /// Returns a new [`NullArray`]. 
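// Illustrative sketch, not part of this diff: NullArray only accepts a dtype whose physical
// type is Null, and every slot is null by construction.
use polars_arrow::array::{Array, NullArray};
use polars_arrow::datatypes::ArrowDataType;

fn null_array_sketch() {
    let arr = NullArray::new(ArrowDataType::Null, 3);
    assert_eq!(arr.null_count(), 3);
    // Any other physical type is rejected by `try_new`.
    assert!(NullArray::try_new(ArrowDataType::Int32, 3).is_err());
}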
- pub fn new_null(data_type: ArrowDataType, length: usize) -> Self { - Self::new(data_type, length) + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { + Self::new(dtype, length) } impl_sliced!(); @@ -111,9 +111,9 @@ impl MutableNullArray { /// Returns a new [`MutableNullArray`]. /// # Panics /// This function errors iff: - /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to [`crate::datatypes::PhysicalType::Null`]. - pub fn new(data_type: ArrowDataType, length: usize) -> Self { - let inner = NullArray::try_new(data_type, length).unwrap(); + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to [`crate::datatypes::PhysicalType::Null`]. + pub fn new(dtype: ArrowDataType, length: usize) -> Self { + let inner = NullArray::try_new(dtype, length).unwrap(); Self { inner } } } @@ -125,7 +125,7 @@ impl From for NullArray { } impl MutableArray for MutableNullArray { - fn data_type(&self) -> &ArrowDataType { + fn dtype(&self) -> &ArrowDataType { &ArrowDataType::Null } @@ -194,12 +194,12 @@ impl Splitable for NullArray { ( Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), validity: lhs, length: offset, }, Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), validity: rhs, length: self.len() - offset, }, @@ -209,8 +209,8 @@ impl Splitable for NullArray { impl FromFfi for NullArray { unsafe fn try_from_ffi(array: A) -> PolarsResult { - let data_type = array.data_type().clone(); - Self::try_new(data_type, array.array().len()) + let dtype = array.dtype().clone(); + Self::try_new(dtype, array.array().len()) } } diff --git a/crates/polars-arrow/src/array/primitive/data.rs b/crates/polars-arrow/src/array/primitive/data.rs index 1a32b230f54f..56a94107cb89 100644 --- a/crates/polars-arrow/src/array/primitive/data.rs +++ b/crates/polars-arrow/src/array/primitive/data.rs @@ -7,9 +7,9 @@ use crate::types::NativeType; impl Arrow2Arrow for PrimitiveArray { fn to_data(&self) -> ArrayData { - let data_type = self.data_type.clone().into(); + let dtype = self.dtype.clone().into(); - let builder = ArrayDataBuilder::new(data_type) + let builder = ArrayDataBuilder::new(dtype) .len(self.len()) .buffers(vec![self.values.clone().into()]) .nulls(self.validity.as_ref().map(|b| b.clone().into())); @@ -19,13 +19,13 @@ impl Arrow2Arrow for PrimitiveArray { } fn from_data(data: &ArrayData) -> Self { - let data_type = data.data_type().clone().into(); + let dtype = data.data_type().clone().into(); let mut values: Buffer = data.buffers()[0].clone().into(); values.slice(data.offset(), data.len()); Self { - data_type, + dtype, values, validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), } diff --git a/crates/polars-arrow/src/array/primitive/ffi.rs b/crates/polars-arrow/src/array/primitive/ffi.rs index ae22cf2e9a9c..6dae1963dd74 100644 --- a/crates/polars-arrow/src/array/primitive/ffi.rs +++ b/crates/polars-arrow/src/array/primitive/ffi.rs @@ -39,7 +39,7 @@ unsafe impl ToFfi for PrimitiveArray { }); Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), validity, values: self.values.clone(), } @@ -48,10 +48,10 @@ unsafe impl ToFfi for PrimitiveArray { impl FromFfi for PrimitiveArray { unsafe fn try_from_ffi(array: A) -> PolarsResult { - let data_type = array.data_type().clone(); + let dtype = array.dtype().clone(); let validity = unsafe { array.validity() }?; let values = unsafe { array.buffer::(1) }?; - Self::try_new(data_type, values, validity) + Self::try_new(dtype, values, validity) } } diff --git 
a/crates/polars-arrow/src/array/primitive/fmt.rs b/crates/polars-arrow/src/array/primitive/fmt.rs index 1b3c5776b180..b349e01843d8 100644 --- a/crates/polars-arrow/src/array/primitive/fmt.rs +++ b/crates/polars-arrow/src/array/primitive/fmt.rs @@ -22,7 +22,7 @@ pub fn get_write_value<'a, T: NativeType, F: Write>( array: &'a PrimitiveArray, ) -> Box Result + 'a> { use crate::datatypes::ArrowDataType::*; - match array.data_type().to_logical_type() { + match array.dtype().to_logical_type() { Int8 => Box::new(|f, index| write!(f, "{}", array.value(index))), Int16 => Box::new(|f, index| write!(f, "{}", array.value(index))), Int32 => Box::new(|f, index| write!(f, "{}", array.value(index))), @@ -56,7 +56,7 @@ pub fn get_write_value<'a, T: NativeType, F: Write>( Time64(_) => unreachable!(), // remaining are not valid Timestamp(time_unit, tz) => { if let Some(tz) = tz { - let timezone = temporal_conversions::parse_offset(tz); + let timezone = temporal_conversions::parse_offset(tz.as_str()); match timezone { Ok(timezone) => { dyn_primitive!(array, i64, |time| { @@ -65,7 +65,7 @@ pub fn get_write_value<'a, T: NativeType, F: Write>( }, #[cfg(feature = "chrono-tz")] Err(_) => { - let timezone = temporal_conversions::parse_offset_tz(tz); + let timezone = temporal_conversions::parse_offset_tz(tz.as_str()); match timezone { Ok(timezone) => dyn_primitive!(array, i64, |time| { temporal_conversions::timestamp_to_datetime( @@ -143,7 +143,7 @@ impl Debug for PrimitiveArray { fn fmt(&self, f: &mut Formatter<'_>) -> Result { let writer = get_write_value(self); - write!(f, "{:?}", self.data_type())?; + write!(f, "{:?}", self.dtype())?; write_vec(f, &*writer, self.validity(), self.len(), "None", false) } } diff --git a/crates/polars-arrow/src/array/primitive/iterator.rs b/crates/polars-arrow/src/array/primitive/iterator.rs index 6c36fca2b2dc..36f680974674 100644 --- a/crates/polars-arrow/src/array/primitive/iterator.rs +++ b/crates/polars-arrow/src/array/primitive/iterator.rs @@ -1,5 +1,3 @@ -use polars_utils::iter::IntoIteratorCopied; - use super::{MutablePrimitiveArray, PrimitiveArray}; use crate::array::{ArrayAccessor, MutableArray}; use crate::bitmap::utils::{BitmapIter, ZipValidity}; @@ -61,12 +59,3 @@ impl<'a, T: NativeType> MutablePrimitiveArray { self.values().iter() } } - -impl IntoIteratorCopied for PrimitiveArray { - type OwnedItem = Option; - type IntoIterCopied = Self::IntoIter; - - fn into_iter(self) -> ::IntoIterCopied { - ::into_iter(self) - } -} diff --git a/crates/polars-arrow/src/array/primitive/mod.rs b/crates/polars-arrow/src/array/primitive/mod.rs index 5ab10deaf47c..5f2d351af535 100644 --- a/crates/polars-arrow/src/array/primitive/mod.rs +++ b/crates/polars-arrow/src/array/primitive/mod.rs @@ -51,13 +51,13 @@ use polars_utils::slice::{GetSaferUnchecked, SliceAble}; /// ``` #[derive(Clone)] pub struct PrimitiveArray { - data_type: ArrowDataType, + dtype: ArrowDataType, values: Buffer, validity: Option, } pub(super) fn check( - data_type: &ArrowDataType, + dtype: &ArrowDataType, values: &[T], validity_len: Option, ) -> PolarsResult<()> { @@ -65,7 +65,7 @@ pub(super) fn check( polars_bail!(ComputeError: "validity mask length must match the number of values") } - if data_type.to_physical_type() != PhysicalType::Primitive(T::PRIMITIVE) { + if dtype.to_physical_type() != PhysicalType::Primitive(T::PRIMITIVE) { polars_bail!(ComputeError: "PrimitiveArray can only be initialized with a DataType whose physical type is Primitive") } Ok(()) @@ -79,15 +79,15 @@ impl PrimitiveArray { /// # Errors /// This 
function errors iff: /// * The validity is not `None` and its length is different from `values`'s length - /// * The `data_type`'s [`PhysicalType`] is not equal to [`PhysicalType::Primitive(T::PRIMITIVE)`] + /// * The `dtype`'s [`PhysicalType`] is not equal to [`PhysicalType::Primitive(T::PRIMITIVE)`] pub fn try_new( - data_type: ArrowDataType, + dtype: ArrowDataType, values: Buffer, validity: Option, ) -> PolarsResult { - check(&data_type, &values, validity.as_ref().map(|v| v.len()))?; + check(&dtype, &values, validity.as_ref().map(|v| v.len()))?; Ok(Self { - data_type, + dtype, values, validity, }) @@ -96,12 +96,12 @@ impl PrimitiveArray { /// # Safety /// Doesn't check invariants pub unsafe fn new_unchecked( - data_type: ArrowDataType, + dtype: ArrowDataType, values: Buffer, validity: Option, ) -> Self { Self { - data_type, + dtype, values, validity, } @@ -123,18 +123,18 @@ impl PrimitiveArray { /// ); /// ``` /// # Panics - /// Panics iff the `data_type`'s [`PhysicalType`] is not equal to [`PhysicalType::Primitive(T::PRIMITIVE)`] + /// Panics iff the `dtype`'s [`PhysicalType`] is not equal to [`PhysicalType::Primitive(T::PRIMITIVE)`] #[inline] #[must_use] - pub fn to(self, data_type: ArrowDataType) -> Self { + pub fn to(self, dtype: ArrowDataType) -> Self { check( - &data_type, + &dtype, &self.values, self.validity.as_ref().map(|v| v.len()), ) .unwrap(); Self { - data_type, + dtype, values: self.values, validity: self.validity, } @@ -192,8 +192,8 @@ impl PrimitiveArray { /// Returns the arrays' [`ArrowDataType`]. #[inline] - pub fn data_type(&self) -> &ArrowDataType { - &self.data_type + pub fn dtype(&self) -> &ArrowDataType { + &self.dtype } /// Returns the value at slot `i`. @@ -302,22 +302,22 @@ impl PrimitiveArray { #[must_use] pub fn into_inner(self) -> (ArrowDataType, Buffer, Option) { let Self { - data_type, + dtype, values, validity, } = self; - (data_type, values, validity) + (dtype, values, validity) } /// Creates a `[PrimitiveArray]` from its internal representation. /// This is the inverted from `[PrimitiveArray::into_inner]` pub fn from_inner( - data_type: ArrowDataType, + dtype: ArrowDataType, values: Buffer, validity: Option, ) -> PolarsResult { - check(&data_type, &values, validity.as_ref().map(|v| v.len()))?; - Ok(unsafe { Self::from_inner_unchecked(data_type, values, validity) }) + check(&dtype, &values, validity.as_ref().map(|v| v.len()))?; + Ok(unsafe { Self::from_inner_unchecked(dtype, values, validity) }) } /// Creates a `[PrimitiveArray]` from its internal representation. @@ -326,12 +326,12 @@ impl PrimitiveArray { /// # Safety /// Callers must ensure all invariants of this struct are upheld. 
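// Illustrative sketch, not part of this diff: `PrimitiveArray::to` re-tags the logical dtype
// in O(1) as long as the physical type matches; `dtype()` is the renamed accessor.
// `from_slice` is assumed from the constructors used elsewhere in this file.
use polars_arrow::array::PrimitiveArray;
use polars_arrow::datatypes::{ArrowDataType, TimeUnit};

fn retag_sketch() {
    let arr = PrimitiveArray::<i64>::from_slice([0i64, 1, 2]);
    // i64 values can carry any Int64-backed logical type, e.g. a timestamp.
    let ts = arr.to(ArrowDataType::Timestamp(TimeUnit::Millisecond, None));
    assert_eq!(ts.dtype(), &ArrowDataType::Timestamp(TimeUnit::Millisecond, None));
}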
pub unsafe fn from_inner_unchecked( - data_type: ArrowDataType, + dtype: ArrowDataType, values: Buffer, validity: Option, ) -> Self { Self { - data_type, + dtype, values, validity, } @@ -350,22 +350,14 @@ impl PrimitiveArray { if let Some(bitmap) = self.validity { match bitmap.into_mut() { - Left(bitmap) => Left(PrimitiveArray::new( - self.data_type, - self.values, - Some(bitmap), - )), + Left(bitmap) => Left(PrimitiveArray::new(self.dtype, self.values, Some(bitmap))), Right(mutable_bitmap) => match self.values.into_mut() { Right(values) => Right( - MutablePrimitiveArray::try_new( - self.data_type, - values, - Some(mutable_bitmap), - ) - .unwrap(), + MutablePrimitiveArray::try_new(self.dtype, values, Some(mutable_bitmap)) + .unwrap(), ), Left(values) => Left(PrimitiveArray::new( - self.data_type, + self.dtype, values, Some(mutable_bitmap.into()), )), @@ -374,23 +366,23 @@ impl PrimitiveArray { } else { match self.values.into_mut() { Right(values) => { - Right(MutablePrimitiveArray::try_new(self.data_type, values, None).unwrap()) + Right(MutablePrimitiveArray::try_new(self.dtype, values, None).unwrap()) }, - Left(values) => Left(PrimitiveArray::new(self.data_type, values, None)), + Left(values) => Left(PrimitiveArray::new(self.dtype, values, None)), } } } /// Returns a new empty (zero-length) [`PrimitiveArray`]. - pub fn new_empty(data_type: ArrowDataType) -> Self { - Self::new(data_type, Buffer::new(), None) + pub fn new_empty(dtype: ArrowDataType) -> Self { + Self::new(dtype, Buffer::new(), None) } /// Returns a new [`PrimitiveArray`] where all slots are null / `None`. #[inline] - pub fn new_null(data_type: ArrowDataType, length: usize) -> Self { + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { Self::new( - data_type, + dtype, vec![T::default(); length].into(), Some(Bitmap::new_zeroed(length)), ) @@ -448,9 +440,9 @@ impl PrimitiveArray { /// # Panics /// This function errors iff: /// * The validity is not `None` and its length is different from `values`'s length - /// * The `data_type`'s [`PhysicalType`] is not equal to [`PhysicalType::Primitive`]. - pub fn new(data_type: ArrowDataType, values: Buffer, validity: Option) -> Self { - Self::try_new(data_type, values, validity).unwrap() + /// * The `dtype`'s [`PhysicalType`] is not equal to [`PhysicalType::Primitive`]. + pub fn new(dtype: ArrowDataType, values: Buffer, validity: Option) -> Self { + Self::try_new(dtype, values, validity).unwrap() } /// Transmute this PrimitiveArray into another PrimitiveArray. @@ -510,12 +502,12 @@ impl Splitable for PrimitiveArray { ( Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), values: lhs_values, validity: lhs_validity, }, Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), values: rhs_values, validity: rhs_validity, }, diff --git a/crates/polars-arrow/src/array/primitive/mutable.rs b/crates/polars-arrow/src/array/primitive/mutable.rs index ae2025482f2c..ab6bfd8c7511 100644 --- a/crates/polars-arrow/src/array/primitive/mutable.rs +++ b/crates/polars-arrow/src/array/primitive/mutable.rs @@ -14,7 +14,7 @@ use crate::types::NativeType; /// Converting a [`MutablePrimitiveArray`] into a [`PrimitiveArray`] is `O(1)`. 
#[derive(Debug, Clone)] pub struct MutablePrimitiveArray { - data_type: ArrowDataType, + dtype: ArrowDataType, values: Vec, validity: Option, } @@ -30,7 +30,7 @@ impl From> for PrimitiveArray { } }); - PrimitiveArray::::new(other.data_type, other.values.into(), validity) + PrimitiveArray::::new(other.dtype, other.values.into(), validity) } } @@ -58,15 +58,15 @@ impl MutablePrimitiveArray { /// # Errors /// This function errors iff: /// * The validity is not `None` and its length is different from `values`'s length - /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to [`crate::datatypes::PhysicalType::Primitive(T::PRIMITIVE)`] + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to [`crate::datatypes::PhysicalType::Primitive(T::PRIMITIVE)`] pub fn try_new( - data_type: ArrowDataType, + dtype: ArrowDataType, values: Vec, validity: Option, ) -> PolarsResult { - check(&data_type, &values, validity.as_ref().map(|x| x.len()))?; + check(&dtype, &values, validity.as_ref().map(|x| x.len()))?; Ok(Self { - data_type, + dtype, values, validity, }) @@ -74,7 +74,7 @@ impl MutablePrimitiveArray { /// Extract the low-end APIs from the [`MutablePrimitiveArray`]. pub fn into_inner(self) -> (ArrowDataType, Vec, Option) { - (self.data_type, self.values, self.validity) + (self.dtype, self.values, self.validity) } /// Applies a function `f` to the values of this array, cloning the values @@ -98,10 +98,10 @@ impl Default for MutablePrimitiveArray { } impl From for MutablePrimitiveArray { - fn from(data_type: ArrowDataType) -> Self { - assert!(data_type.to_physical_type().eq_primitive(T::PRIMITIVE)); + fn from(dtype: ArrowDataType) -> Self { + assert!(dtype.to_physical_type().eq_primitive(T::PRIMITIVE)); Self { - data_type, + dtype, values: Vec::::new(), validity: None, } @@ -110,10 +110,10 @@ impl From for MutablePrimitiveArray { impl MutablePrimitiveArray { /// Creates a new [`MutablePrimitiveArray`] from a capacity and [`ArrowDataType`]. - pub fn with_capacity_from(capacity: usize, data_type: ArrowDataType) -> Self { - assert!(data_type.to_physical_type().eq_primitive(T::PRIMITIVE)); + pub fn with_capacity_from(capacity: usize, dtype: ArrowDataType) -> Self { + assert!(dtype.to_physical_type().eq_primitive(T::PRIMITIVE)); Self { - data_type, + dtype, values: Vec::::with_capacity(capacity), validity: None, } @@ -130,9 +130,8 @@ impl MutablePrimitiveArray { #[inline] pub fn push_value(&mut self, value: T) { self.values.push(value); - match &mut self.validity { - Some(validity) => validity.push(true), - None => {}, + if let Some(validity) = &mut self.validity { + validity.push(true) } } @@ -265,8 +264,8 @@ impl MutablePrimitiveArray { /// # Implementation /// This operation is `O(1)`. #[inline] - pub fn to(self, data_type: ArrowDataType) -> Self { - Self::try_new(data_type, self.values, self.validity).unwrap() + pub fn to(self, dtype: ArrowDataType) -> Self { + Self::try_new(dtype, self.values, self.validity).unwrap() } /// Converts itself into an [`Array`]. 
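// Illustrative sketch, not part of this diff: pushing into a MutablePrimitiveArray and
// freezing it is O(1); `push_value` only touches the validity bitmap if one already exists.
use polars_arrow::array::{Array, MutablePrimitiveArray, PrimitiveArray};

fn mutable_primitive_sketch() {
    let mut arr = MutablePrimitiveArray::<f64>::new();
    arr.push(Some(1.0));
    arr.push(None);
    arr.push_value(3.0);
    let frozen: PrimitiveArray<f64> = arr.into();
    assert_eq!(frozen.null_count(), 1);
}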
@@ -414,7 +413,7 @@ impl MutableArray for MutablePrimitiveArray { fn as_box(&mut self) -> Box { PrimitiveArray::new( - self.data_type.clone(), + self.dtype.clone(), std::mem::take(&mut self.values).into(), std::mem::take(&mut self.validity).map(|x| x.into()), ) @@ -423,15 +422,15 @@ impl MutableArray for MutablePrimitiveArray { fn as_arc(&mut self) -> Arc { PrimitiveArray::new( - self.data_type.clone(), + self.dtype.clone(), std::mem::take(&mut self.values).into(), std::mem::take(&mut self.validity).map(|x| x.into()), ) .arced() } - fn data_type(&self) -> &ArrowDataType { - &self.data_type + fn dtype(&self) -> &ArrowDataType { + &self.dtype } fn as_any(&self) -> &dyn std::any::Any { @@ -475,7 +474,7 @@ impl MutablePrimitiveArray { let (validity, values) = trusted_len_unzip(iterator); Self { - data_type: T::PRIMITIVE.into(), + dtype: T::PRIMITIVE.into(), values, validity, } @@ -509,7 +508,7 @@ impl MutablePrimitiveArray { let (validity, values) = try_trusted_len_unzip(iterator)?; Ok(Self { - data_type: T::PRIMITIVE.into(), + dtype: T::PRIMITIVE.into(), values, validity, }) @@ -528,7 +527,7 @@ impl MutablePrimitiveArray { /// Creates a new [`MutablePrimitiveArray`] out an iterator over values pub fn from_trusted_len_values_iter>(iter: I) -> Self { Self { - data_type: T::PRIMITIVE.into(), + dtype: T::PRIMITIVE.into(), values: iter.collect(), validity: None, } @@ -547,7 +546,7 @@ impl MutablePrimitiveArray { /// I.e. that `size_hint().1` correctly reports its length. pub unsafe fn from_trusted_len_values_iter_unchecked>(iter: I) -> Self { Self { - data_type: T::PRIMITIVE.into(), + dtype: T::PRIMITIVE.into(), values: iter.collect(), validity: None, } @@ -578,7 +577,7 @@ impl>> FromIterator let validity = Some(validity); Self { - data_type: T::PRIMITIVE.into(), + dtype: T::PRIMITIVE.into(), values, validity, } diff --git a/crates/polars-arrow/src/array/struct_/data.rs b/crates/polars-arrow/src/array/struct_/data.rs index 4dfcb0010a73..ca8c5b0c6ec3 100644 --- a/crates/polars-arrow/src/array/struct_/data.rs +++ b/crates/polars-arrow/src/array/struct_/data.rs @@ -5,9 +5,9 @@ use crate::bitmap::Bitmap; impl Arrow2Arrow for StructArray { fn to_data(&self) -> ArrayData { - let data_type = self.data_type.clone().into(); + let dtype = self.dtype.clone().into(); - let builder = ArrayDataBuilder::new(data_type) + let builder = ArrayDataBuilder::new(dtype) .len(self.len()) .nulls(self.validity.as_ref().map(|b| b.clone().into())) .child_data(self.values.iter().map(|x| to_data(x.as_ref())).collect()); @@ -17,10 +17,10 @@ impl Arrow2Arrow for StructArray { } fn from_data(data: &ArrayData) -> Self { - let data_type = data.data_type().clone().into(); + let dtype = data.data_type().clone().into(); Self { - data_type, + dtype, values: data.child_data().iter().map(from_data).collect(), validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), } diff --git a/crates/polars-arrow/src/array/struct_/ffi.rs b/crates/polars-arrow/src/array/struct_/ffi.rs index 76522b8efd7c..3bfb9a1a7d7f 100644 --- a/crates/polars-arrow/src/array/struct_/ffi.rs +++ b/crates/polars-arrow/src/array/struct_/ffi.rs @@ -30,8 +30,8 @@ unsafe impl ToFfi for StructArray { impl FromFfi for StructArray { unsafe fn try_from_ffi(array: A) -> PolarsResult { - let data_type = array.data_type().clone(); - let fields = Self::get_fields(&data_type); + let dtype = array.dtype().clone(); + let fields = Self::get_fields(&dtype); let arrow_array = array.array(); let validity = unsafe { array.validity() }?; @@ -68,6 +68,6 @@ impl FromFfi for 
StructArray { }) .collect::>>>()?; - Self::try_new(data_type, values, validity) + Self::try_new(dtype, values, validity) } } diff --git a/crates/polars-arrow/src/array/struct_/mod.rs b/crates/polars-arrow/src/array/struct_/mod.rs index 08a56aa0fee1..efac13a481ea 100644 --- a/crates/polars-arrow/src/array/struct_/mod.rs +++ b/crates/polars-arrow/src/array/struct_/mod.rs @@ -23,15 +23,15 @@ use crate::compute::utils::combine_validities_and; /// let int = Int32Array::from_slice(&[42, 28, 19, 31]).boxed(); /// /// let fields = vec![ -/// Field::new("b", ArrowDataType::Boolean, false), -/// Field::new("c", ArrowDataType::Int32, false), +/// Field::new("b".into(), ArrowDataType::Boolean, false), +/// Field::new("c".into(), ArrowDataType::Int32, false), /// ]; /// /// let array = StructArray::new(ArrowDataType::Struct(fields), vec![boolean, int], None); /// ``` #[derive(Clone)] pub struct StructArray { - data_type: ArrowDataType, + dtype: ArrowDataType, values: Vec>, validity: Option, } @@ -40,34 +40,40 @@ impl StructArray { /// Returns a new [`StructArray`]. /// # Errors /// This function errors iff: - /// * `data_type`'s physical type is not [`crate::datatypes::PhysicalType::Struct`]. - /// * the children of `data_type` are empty + /// * `dtype`'s physical type is not [`crate::datatypes::PhysicalType::Struct`]. + /// * the children of `dtype` are empty /// * the values's len is different from children's length /// * any of the values's data type is different from its corresponding children' data type /// * any element of values has a different length than the first element /// * the validity's length is not equal to the length of the first element pub fn try_new( - data_type: ArrowDataType, + dtype: ArrowDataType, values: Vec>, validity: Option, ) -> PolarsResult { - let fields = Self::try_get_fields(&data_type)?; + let fields = Self::try_get_fields(&dtype)?; if fields.is_empty() { - polars_bail!(ComputeError: "a StructArray must contain at least one field") + assert!(values.is_empty(), "invalid struct"); + assert_eq!(validity.map(|v| v.len()).unwrap_or(0), 0, "invalid struct"); + return Ok(Self { + dtype, + values, + validity: None, + }); } if fields.len() != values.len() { polars_bail!(ComputeError:"a StructArray must have a number of fields in its DataType equal to the number of child values") } fields - .iter().map(|a| &a.data_type) - .zip(values.iter().map(|a| a.data_type())) + .iter().map(|a| &a.dtype) + .zip(values.iter().map(|a| a.dtype())) .enumerate() - .try_for_each(|(index, (data_type, child))| { - if data_type != child { + .try_for_each(|(index, (dtype, child))| { + if dtype != child { polars_bail!(ComputeError: "The children DataTypes of a StructArray must equal the children data types. - However, the field {index} has data type {data_type:?} but the value has data type {child:?}" + However, the field {index} has data type {dtype:?} but the value has data type {child:?}" ) } else { Ok(()) @@ -96,7 +102,7 @@ impl StructArray { } Ok(Self { - data_type, + dtype, values, validity, }) @@ -105,41 +111,41 @@ impl StructArray { /// Returns a new [`StructArray`] /// # Panics /// This function panics iff: - /// * `data_type`'s physical type is not [`crate::datatypes::PhysicalType::Struct`]. - /// * the children of `data_type` are empty + /// * `dtype`'s physical type is not [`crate::datatypes::PhysicalType::Struct`]. 
+ /// * the children of `dtype` are empty /// * the values's len is different from children's length /// * any of the values's data type is different from its corresponding children' data type /// * any element of values has a different length than the first element /// * the validity's length is not equal to the length of the first element pub fn new( - data_type: ArrowDataType, + dtype: ArrowDataType, values: Vec>, validity: Option, ) -> Self { - Self::try_new(data_type, values, validity).unwrap() + Self::try_new(dtype, values, validity).unwrap() } /// Creates an empty [`StructArray`]. - pub fn new_empty(data_type: ArrowDataType) -> Self { - if let ArrowDataType::Struct(fields) = &data_type.to_logical_type() { + pub fn new_empty(dtype: ArrowDataType) -> Self { + if let ArrowDataType::Struct(fields) = &dtype.to_logical_type() { let values = fields .iter() - .map(|field| new_empty_array(field.data_type().clone())) + .map(|field| new_empty_array(field.dtype().clone())) .collect(); - Self::new(data_type, values, None) + Self::new(dtype, values, None) } else { panic!("StructArray must be initialized with DataType::Struct"); } } /// Creates a null [`StructArray`] of length `length`. - pub fn new_null(data_type: ArrowDataType, length: usize) -> Self { - if let ArrowDataType::Struct(fields) = &data_type { + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { + if let ArrowDataType::Struct(fields) = &dtype { let values = fields .iter() - .map(|field| new_null_array(field.data_type().clone(), length)) + .map(|field| new_null_array(field.dtype().clone(), length)) .collect(); - Self::new(data_type, values, Some(Bitmap::new_zeroed(length))) + Self::new(dtype, values, Some(Bitmap::new_zeroed(length))) } else { panic!("StructArray must be initialized with DataType::Struct"); } @@ -152,11 +158,11 @@ impl StructArray { #[must_use] pub fn into_data(self) -> (Vec, Vec>, Option) { let Self { - data_type, + dtype, values, validity, } = self; - let fields = if let ArrowDataType::Struct(fields) = data_type { + let fields = if let ArrowDataType::Struct(fields) = dtype { fields } else { unreachable!() @@ -220,7 +226,7 @@ impl StructArray { impl StructArray { #[inline] fn len(&self) -> usize { - self.values[0].len() + self.values.first().map(|arr| arr.len()).unwrap_or(0) } /// The optional validity. @@ -236,14 +242,14 @@ impl StructArray { /// Returns the fields of this [`StructArray`]. pub fn fields(&self) -> &[Field] { - Self::get_fields(&self.data_type) + Self::get_fields(&self.dtype) } } impl StructArray { /// Returns the fields the `DataType::Struct`. - pub(crate) fn try_get_fields(data_type: &ArrowDataType) -> PolarsResult<&[Field]> { - match data_type.to_logical_type() { + pub(crate) fn try_get_fields(dtype: &ArrowDataType) -> PolarsResult<&[Field]> { + match dtype.to_logical_type() { ArrowDataType::Struct(fields) => Ok(fields), _ => { polars_bail!(ComputeError: "Struct array must be created with a DataType whose physical type is Struct") @@ -252,8 +258,8 @@ impl StructArray { } /// Returns the fields the `DataType::Struct`. 
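// Illustrative sketch, not of this diff's code but of the behaviour it introduces:
// `StructArray::try_new` now accepts a Struct dtype with zero fields (and zero values)
// instead of erroring, and such an array reports length 0.
use polars_arrow::array::{Array, StructArray};
use polars_arrow::datatypes::ArrowDataType;

fn zero_field_struct_sketch() {
    let arr = StructArray::try_new(ArrowDataType::Struct(vec![]), vec![], None).unwrap();
    assert_eq!(arr.len(), 0);
}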
- pub fn get_fields(data_type: &ArrowDataType) -> &[Field] { - Self::try_get_fields(data_type).unwrap() + pub fn get_fields(dtype: &ArrowDataType) -> &[Field] { + Self::try_get_fields(dtype).unwrap() } } @@ -289,12 +295,12 @@ impl Splitable for StructArray { ( Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), values: lhs_values, validity: lhs_validity, }, Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), values: rhs_values, validity: rhs_validity, }, diff --git a/crates/polars-arrow/src/array/struct_/mutable.rs b/crates/polars-arrow/src/array/struct_/mutable.rs index d748f7743b32..286db07e2f97 100644 --- a/crates/polars-arrow/src/array/struct_/mutable.rs +++ b/crates/polars-arrow/src/array/struct_/mutable.rs @@ -10,17 +10,17 @@ use crate::datatypes::ArrowDataType; /// Converting a [`MutableStructArray`] into a [`StructArray`] is `O(1)`. #[derive(Debug)] pub struct MutableStructArray { - data_type: ArrowDataType, + dtype: ArrowDataType, values: Vec>, validity: Option, } fn check( - data_type: &ArrowDataType, + dtype: &ArrowDataType, values: &[Box], validity: Option, ) -> PolarsResult<()> { - let fields = StructArray::try_get_fields(data_type)?; + let fields = StructArray::try_get_fields(dtype)?; if fields.is_empty() { polars_bail!(ComputeError: "a StructArray must contain at least one field") } @@ -29,14 +29,14 @@ fn check( } fields - .iter().map(|a| &a.data_type) - .zip(values.iter().map(|a| a.data_type())) + .iter().map(|a| &a.dtype) + .zip(values.iter().map(|a| a.dtype())) .enumerate() - .try_for_each(|(index, (data_type, child))| { - if data_type != child { + .try_for_each(|(index, (dtype, child))| { + if dtype != child { polars_bail!(ComputeError: "The children DataTypes of a StructArray must equal the children data types. - However, the field {index} has data type {data_type:?} but the value has data type {child:?}" + However, the field {index} has data type {dtype:?} but the value has data type {child:?}" ) } else { Ok(()) @@ -76,7 +76,7 @@ impl From for StructArray { }; StructArray::new( - other.data_type, + other.dtype, other.values.into_iter().map(|mut v| v.as_box()).collect(), validity, ) @@ -85,24 +85,24 @@ impl From for StructArray { impl MutableStructArray { /// Creates a new [`MutableStructArray`]. - pub fn new(data_type: ArrowDataType, values: Vec>) -> Self { - Self::try_new(data_type, values, None).unwrap() + pub fn new(dtype: ArrowDataType, values: Vec>) -> Self { + Self::try_new(dtype, values, None).unwrap() } /// Create a [`MutableStructArray`] out of low-end APIs. 
/// # Errors /// This function errors iff: - /// * `data_type` is not [`ArrowDataType::Struct`] - /// * The inner types of `data_type` are not equal to those of `values` + /// * `dtype` is not [`ArrowDataType::Struct`] + /// * The inner types of `dtype` are not equal to those of `values` /// * `validity` is not `None` and its length is different from the `values`'s length pub fn try_new( - data_type: ArrowDataType, + dtype: ArrowDataType, values: Vec>, validity: Option, ) -> PolarsResult { - check(&data_type, &values, validity.as_ref().map(|x| x.len()))?; + check(&dtype, &values, validity.as_ref().map(|x| x.len()))?; Ok(Self { - data_type, + dtype, values, validity, }) @@ -116,7 +116,7 @@ impl MutableStructArray { Vec>, Option, ) { - (self.data_type, self.values, self.validity) + (self.dtype, self.values, self.validity) } /// The mutable values @@ -202,7 +202,7 @@ impl MutableArray for MutableStructArray { fn as_box(&mut self) -> Box { StructArray::new( - self.data_type.clone(), + self.dtype.clone(), std::mem::take(&mut self.values) .into_iter() .map(|mut v| v.as_box()) @@ -214,7 +214,7 @@ impl MutableArray for MutableStructArray { fn as_arc(&mut self) -> Arc { StructArray::new( - self.data_type.clone(), + self.dtype.clone(), std::mem::take(&mut self.values) .into_iter() .map(|mut v| v.as_box()) @@ -224,8 +224,8 @@ impl MutableArray for MutableStructArray { .arced() } - fn data_type(&self) -> &ArrowDataType { - &self.data_type + fn dtype(&self) -> &ArrowDataType { + &self.dtype } fn as_any(&self) -> &dyn std::any::Any { diff --git a/crates/polars-arrow/src/array/union/data.rs b/crates/polars-arrow/src/array/union/data.rs index 4303ab7b4356..869fdcfc248d 100644 --- a/crates/polars-arrow/src/array/union/data.rs +++ b/crates/polars-arrow/src/array/union/data.rs @@ -6,15 +6,15 @@ use crate::datatypes::ArrowDataType; impl Arrow2Arrow for UnionArray { fn to_data(&self) -> ArrayData { - let data_type = arrow_schema::DataType::from(self.data_type.clone()); + let dtype = arrow_schema::DataType::from(self.dtype.clone()); let len = self.len(); let builder = match self.offsets.clone() { - Some(offsets) => ArrayDataBuilder::new(data_type) + Some(offsets) => ArrayDataBuilder::new(dtype) .len(len) .buffers(vec![self.types.clone().into(), offsets.into()]) .child_data(self.fields.iter().map(|x| to_data(x.as_ref())).collect()), - None => ArrayDataBuilder::new(data_type) + None => ArrayDataBuilder::new(dtype) .len(len) .buffers(vec![self.types.clone().into()]) .child_data( @@ -30,7 +30,7 @@ impl Arrow2Arrow for UnionArray { } fn from_data(data: &ArrayData) -> Self { - let data_type: ArrowDataType = data.data_type().clone().into(); + let dtype: ArrowDataType = data.data_type().clone().into(); let fields = data.child_data().iter().map(from_data).collect(); let buffers = data.buffers(); @@ -46,7 +46,7 @@ impl Arrow2Arrow for UnionArray { }; // Map from type id to array index - let map = match &data_type { + let map = match &dtype { ArrowDataType::Union(_, Some(ids), _) => { let mut map = [0; 127]; for (pos, &id) in ids.iter().enumerate() { @@ -63,7 +63,7 @@ impl Arrow2Arrow for UnionArray { map, fields, offsets, - data_type, + dtype, offset: data.offset(), } } diff --git a/crates/polars-arrow/src/array/union/ffi.rs b/crates/polars-arrow/src/array/union/ffi.rs index 1510b29e2588..d9a2601a6019 100644 --- a/crates/polars-arrow/src/array/union/ffi.rs +++ b/crates/polars-arrow/src/array/union/ffi.rs @@ -33,11 +33,11 @@ unsafe impl ToFfi for UnionArray { impl FromFfi for UnionArray { unsafe fn try_from_ffi(array: A) 
-> PolarsResult { - let data_type = array.data_type().clone(); - let fields = Self::get_fields(&data_type); + let dtype = array.dtype().clone(); + let fields = Self::get_fields(&dtype); let mut types = unsafe { array.buffer::(0) }?; - let offsets = if Self::is_sparse(&data_type) { + let offsets = if Self::is_sparse(&dtype) { None } else { Some(unsafe { array.buffer::(1) }?) @@ -56,6 +56,6 @@ impl FromFfi for UnionArray { types.slice(offset, length); }; - Self::try_new(data_type, types, fields, offsets) + Self::try_new(dtype, types, fields, offsets) } } diff --git a/crates/polars-arrow/src/array/union/mod.rs b/crates/polars-arrow/src/array/union/mod.rs index d1221b812eae..e42d268f5c06 100644 --- a/crates/polars-arrow/src/array/union/mod.rs +++ b/crates/polars-arrow/src/array/union/mod.rs @@ -34,7 +34,7 @@ pub struct UnionArray { fields: Vec>, // Invariant: when set, `offsets.len() == types.len()` offsets: Option>, - data_type: ArrowDataType, + dtype: ArrowDataType, offset: usize, } @@ -42,17 +42,17 @@ impl UnionArray { /// Returns a new [`UnionArray`]. /// # Errors /// This function errors iff: - /// * `data_type`'s physical type is not [`crate::datatypes::PhysicalType::Union`]. - /// * the fields's len is different from the `data_type`'s children's length + /// * `dtype`'s physical type is not [`crate::datatypes::PhysicalType::Union`]. + /// * the fields's len is different from the `dtype`'s children's length /// * The number of `fields` is larger than `i8::MAX` /// * any of the values's data type is different from its corresponding children' data type pub fn try_new( - data_type: ArrowDataType, + dtype: ArrowDataType, types: Buffer, fields: Vec>, offsets: Option>, ) -> PolarsResult { - let (f, ids, mode) = Self::try_get_all(&data_type)?; + let (f, ids, mode) = Self::try_get_all(&dtype)?; if f.len() != fields.len() { polars_bail!(ComputeError: "the number of `fields` must equal the number of children fields in DataType::Union") @@ -62,14 +62,14 @@ impl UnionArray { )?; f - .iter().map(|a| a.data_type()) - .zip(fields.iter().map(|a| a.data_type())) + .iter().map(|a| a.dtype()) + .zip(fields.iter().map(|a| a.dtype())) .enumerate() - .try_for_each(|(index, (data_type, child))| { - if data_type != child { + .try_for_each(|(index, (dtype, child))| { + if dtype != child { polars_bail!(ComputeError: "the children DataTypes of a UnionArray must equal the children data types. - However, the field {index} has data type {data_type:?} but the value has data type {child:?}" + However, the field {index} has data type {dtype:?} but the value has data type {child:?}" ) } else { Ok(()) @@ -147,7 +147,7 @@ impl UnionArray { }; Ok(Self { - data_type, + dtype, map, fields, offsets, @@ -159,24 +159,24 @@ impl UnionArray { /// Returns a new [`UnionArray`]. /// # Panics /// This function panics iff: - /// * `data_type`'s physical type is not [`crate::datatypes::PhysicalType::Union`]. - /// * the fields's len is different from the `data_type`'s children's length + /// * `dtype`'s physical type is not [`crate::datatypes::PhysicalType::Union`]. + /// * the fields's len is different from the `dtype`'s children's length /// * any of the values's data type is different from its corresponding children' data type pub fn new( - data_type: ArrowDataType, + dtype: ArrowDataType, types: Buffer, fields: Vec>, offsets: Option>, ) -> Self { - Self::try_new(data_type, types, fields, offsets).unwrap() + Self::try_new(dtype, types, fields, offsets).unwrap() } /// Creates a new null [`UnionArray`]. 
- pub fn new_null(data_type: ArrowDataType, length: usize) -> Self { - if let ArrowDataType::Union(f, _, mode) = &data_type { + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { + if let ArrowDataType::Union(f, _, mode) = &dtype { let fields = f .iter() - .map(|x| new_null_array(x.data_type().clone(), length)) + .map(|x| new_null_array(x.dtype().clone(), length)) .collect(); let offsets = if mode.is_sparse() { @@ -188,18 +188,18 @@ impl UnionArray { // all from the same field let types = vec![0i8; length].into(); - Self::new(data_type, types, fields, offsets) + Self::new(dtype, types, fields, offsets) } else { panic!("Union struct must be created with the corresponding Union DataType") } } /// Creates a new empty [`UnionArray`]. - pub fn new_empty(data_type: ArrowDataType) -> Self { - if let ArrowDataType::Union(f, _, mode) = data_type.to_logical_type() { + pub fn new_empty(dtype: ArrowDataType) -> Self { + if let ArrowDataType::Union(f, _, mode) = dtype.to_logical_type() { let fields = f .iter() - .map(|x| new_empty_array(x.data_type().clone())) + .map(|x| new_empty_array(x.dtype().clone())) .collect(); let offsets = if mode.is_sparse() { @@ -209,7 +209,7 @@ impl UnionArray { }; Self { - data_type, + dtype, map: None, fields, offsets, @@ -351,8 +351,8 @@ impl Array for UnionArray { } impl UnionArray { - fn try_get_all(data_type: &ArrowDataType) -> PolarsResult { - match data_type.to_logical_type() { + fn try_get_all(dtype: &ArrowDataType) -> PolarsResult { + match dtype.to_logical_type() { ArrowDataType::Union(fields, ids, mode) => { Ok((fields, ids.as_ref().map(|x| x.as_ref()), *mode)) }, @@ -362,22 +362,22 @@ impl UnionArray { } } - fn get_all(data_type: &ArrowDataType) -> (&[Field], Option<&[i32]>, UnionMode) { - Self::try_get_all(data_type).unwrap() + fn get_all(dtype: &ArrowDataType) -> (&[Field], Option<&[i32]>, UnionMode) { + Self::try_get_all(dtype).unwrap() } /// Returns all fields from [`ArrowDataType::Union`]. /// # Panic - /// Panics iff `data_type`'s logical type is not [`ArrowDataType::Union`]. - pub fn get_fields(data_type: &ArrowDataType) -> &[Field] { - Self::get_all(data_type).0 + /// Panics iff `dtype`'s logical type is not [`ArrowDataType::Union`]. + pub fn get_fields(dtype: &ArrowDataType) -> &[Field] { + Self::get_all(dtype).0 } /// Returns whether the [`ArrowDataType::Union`] is sparse or not. /// # Panic - /// Panics iff `data_type`'s logical type is not [`ArrowDataType::Union`]. - pub fn is_sparse(data_type: &ArrowDataType) -> bool { - Self::get_all(data_type).2.is_sparse() + /// Panics iff `dtype`'s logical type is not [`ArrowDataType::Union`]. 
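// Illustrative sketch, not part of this diff: the dtype-based helpers on UnionArray keep
// their semantics under the rename. `UnionMode` is assumed to be re-exported from `datatypes`.
use polars_arrow::array::UnionArray;
use polars_arrow::datatypes::{ArrowDataType, Field, UnionMode};

fn union_helpers_sketch() {
    let fields = vec![Field::new("i".into(), ArrowDataType::Int32, true)];
    let dtype = ArrowDataType::Union(fields, None, UnionMode::Sparse);
    assert!(UnionArray::is_sparse(&dtype));
    assert_eq!(UnionArray::get_fields(&dtype).len(), 1);
}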
+ pub fn is_sparse(dtype: &ArrowDataType) -> bool { + Self::get_all(dtype).2.is_sparse() } } @@ -399,7 +399,7 @@ impl Splitable for UnionArray { map: self.map, fields: self.fields.clone(), offsets: lhs_offsets, - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), offset: self.offset, }, Self { @@ -407,7 +407,7 @@ impl Splitable for UnionArray { map: self.map, fields: self.fields.clone(), offsets: rhs_offsets, - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), offset: self.offset + offset, }, ) diff --git a/crates/polars-arrow/src/array/utf8/data.rs b/crates/polars-arrow/src/array/utf8/data.rs index 577a43677c05..37f73a089aa6 100644 --- a/crates/polars-arrow/src/array/utf8/data.rs +++ b/crates/polars-arrow/src/array/utf8/data.rs @@ -6,8 +6,8 @@ use crate::offset::{Offset, OffsetsBuffer}; impl Arrow2Arrow for Utf8Array { fn to_data(&self) -> ArrayData { - let data_type = self.data_type().clone().into(); - let builder = ArrayDataBuilder::new(data_type) + let dtype = self.dtype().clone().into(); + let builder = ArrayDataBuilder::new(dtype) .len(self.offsets().len_proxy()) .buffers(vec![ self.offsets.clone().into_inner().into(), @@ -20,10 +20,10 @@ impl Arrow2Arrow for Utf8Array { } fn from_data(data: &ArrayData) -> Self { - let data_type = data.data_type().clone().into(); + let dtype = data.data_type().clone().into(); if data.is_empty() { // Handle empty offsets - return Self::new_empty(data_type); + return Self::new_empty(dtype); } let buffers = data.buffers(); @@ -33,7 +33,7 @@ impl Arrow2Arrow for Utf8Array { offsets.slice(data.offset(), data.len() + 1); Self { - data_type, + dtype, offsets, values: buffers[1].clone().into(), validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), diff --git a/crates/polars-arrow/src/array/utf8/ffi.rs b/crates/polars-arrow/src/array/utf8/ffi.rs index 5bdced4df6f1..7181eba91286 100644 --- a/crates/polars-arrow/src/array/utf8/ffi.rs +++ b/crates/polars-arrow/src/array/utf8/ffi.rs @@ -40,7 +40,7 @@ unsafe impl ToFfi for Utf8Array { }); Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), validity, offsets: self.offsets.clone(), values: self.values.clone(), @@ -50,7 +50,7 @@ unsafe impl ToFfi for Utf8Array { impl FromFfi for Utf8Array { unsafe fn try_from_ffi(array: A) -> PolarsResult { - let data_type = array.data_type().clone(); + let dtype = array.dtype().clone(); let validity = unsafe { array.validity() }?; let offsets = unsafe { array.buffer::(1) }?; let values = unsafe { array.buffer::(2)? }; @@ -58,6 +58,6 @@ impl FromFfi for Utf8Array { // assumption that data from FFI is well constructed let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets) }; - Ok(Self::new_unchecked(data_type, offsets, values, validity)) + Ok(Self::new_unchecked(dtype, offsets, values, validity)) } } diff --git a/crates/polars-arrow/src/array/utf8/mod.rs b/crates/polars-arrow/src/array/utf8/mod.rs index 03c4ca1caabb..ebec52b78d28 100644 --- a/crates/polars-arrow/src/array/utf8/mod.rs +++ b/crates/polars-arrow/src/array/utf8/mod.rs @@ -65,7 +65,7 @@ impl> AsRef<[u8]> for StrAsBytes { /// * `len` is equal to `validity.len()`, when defined. #[derive(Clone)] pub struct Utf8Array { - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: OffsetsBuffer, values: Buffer, validity: Option, @@ -79,12 +79,12 @@ impl Utf8Array { /// This function returns an error iff: /// * The last offset is not equal to the values' length. /// * the validity's length is not equal to `offsets.len()`. 
- /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. /// * The `values` between two consecutive `offsets` are not valid utf8 /// # Implementation /// This function is `O(N)` - checking utf8 is `O(N)` pub fn try_new( - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: OffsetsBuffer, values: Buffer, validity: Option, @@ -97,12 +97,12 @@ impl Utf8Array { polars_bail!(ComputeError: "validity mask length must match the number of values"); } - if data_type.to_physical_type() != Self::default_data_type().to_physical_type() { + if dtype.to_physical_type() != Self::default_dtype().to_physical_type() { polars_bail!(ComputeError: "Utf8Array can only be initialized with DataType::Utf8 or DataType::LargeUtf8") } Ok(Self { - data_type, + dtype, offsets, values, validity, @@ -186,8 +186,8 @@ impl Utf8Array { /// Returns the [`ArrowDataType`] of this array. #[inline] - pub fn data_type(&self) -> &ArrowDataType { - &self.data_type + pub fn dtype(&self) -> &ArrowDataType { + &self.dtype } /// Returns the values of this [`Utf8Array`]. @@ -244,12 +244,12 @@ impl Utf8Array { #[must_use] pub fn into_inner(self) -> (ArrowDataType, OffsetsBuffer, Buffer, Option) { let Self { - data_type, + dtype, offsets, values, validity, } = self; - (data_type, offsets, values, validity) + (dtype, offsets, values, validity) } /// Try to convert this `Utf8Array` to a `MutableUtf8Array` @@ -260,19 +260,14 @@ impl Utf8Array { match bitmap.into_mut() { // SAFETY: invariants are preserved Left(bitmap) => Left(unsafe { - Utf8Array::new_unchecked( - self.data_type, - self.offsets, - self.values, - Some(bitmap), - ) + Utf8Array::new_unchecked(self.dtype, self.offsets, self.values, Some(bitmap)) }), Right(mutable_bitmap) => match (self.values.into_mut(), self.offsets.into_mut()) { (Left(values), Left(offsets)) => { // SAFETY: invariants are preserved Left(unsafe { Utf8Array::new_unchecked( - self.data_type, + self.dtype, offsets, values, Some(mutable_bitmap.into()), @@ -283,7 +278,7 @@ impl Utf8Array { // SAFETY: invariants are preserved Left(unsafe { Utf8Array::new_unchecked( - self.data_type, + self.dtype, offsets.into(), values, Some(mutable_bitmap.into()), @@ -294,7 +289,7 @@ impl Utf8Array { // SAFETY: invariants are preserved Left(unsafe { Utf8Array::new_unchecked( - self.data_type, + self.dtype, offsets, values.into(), Some(mutable_bitmap.into()), @@ -303,7 +298,7 @@ impl Utf8Array { }, (Right(values), Right(offsets)) => Right(unsafe { MutableUtf8Array::new_unchecked( - self.data_type, + self.dtype, offsets, values, Some(mutable_bitmap), @@ -314,16 +309,16 @@ impl Utf8Array { } else { match (self.values.into_mut(), self.offsets.into_mut()) { (Left(values), Left(offsets)) => { - Left(unsafe { Utf8Array::new_unchecked(self.data_type, offsets, values, None) }) + Left(unsafe { Utf8Array::new_unchecked(self.dtype, offsets, values, None) }) }, (Left(values), Right(offsets)) => Left(unsafe { - Utf8Array::new_unchecked(self.data_type, offsets.into(), values, None) + Utf8Array::new_unchecked(self.dtype, offsets.into(), values, None) }), (Right(values), Left(offsets)) => Left(unsafe { - Utf8Array::new_unchecked(self.data_type, offsets, values.into(), None) + Utf8Array::new_unchecked(self.dtype, offsets, values.into(), None) }), (Right(values), Right(offsets)) => Right(unsafe { - MutableUtf8Array::new_unchecked(self.data_type, offsets, values, None) + 
MutableUtf8Array::new_unchecked(self.dtype, offsets, values, None) }), } } @@ -333,15 +328,15 @@ impl Utf8Array { /// /// The array is guaranteed to have no elements nor validity. #[inline] - pub fn new_empty(data_type: ArrowDataType) -> Self { - unsafe { Self::new_unchecked(data_type, OffsetsBuffer::new(), Buffer::new(), None) } + pub fn new_empty(dtype: ArrowDataType) -> Self { + unsafe { Self::new_unchecked(dtype, OffsetsBuffer::new(), Buffer::new(), None) } } /// Returns a new [`Utf8Array`] whose all slots are null / `None`. #[inline] - pub fn new_null(data_type: ArrowDataType, length: usize) -> Self { + pub fn new_null(dtype: ArrowDataType, length: usize) -> Self { Self::new( - data_type, + dtype, Offsets::new_zeroed(length).into(), Buffer::new(), Some(Bitmap::new_zeroed(length)), @@ -349,7 +344,7 @@ impl Utf8Array { } /// Returns a default [`ArrowDataType`] of this array, which depends on the generic parameter `O`: `DataType::Utf8` or `DataType::LargeUtf8` - pub fn default_data_type() -> ArrowDataType { + pub fn default_dtype() -> ArrowDataType { if O::IS_LARGE { ArrowDataType::LargeUtf8 } else { @@ -363,7 +358,7 @@ impl Utf8Array { /// This function panics (in debug mode only) iff: /// * The last offset is not equal to the values' length. /// * the validity's length is not equal to `offsets.len()`. - /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. /// /// # Safety /// This function is unsound iff: @@ -371,7 +366,7 @@ impl Utf8Array { /// # Implementation /// This function is `O(1)` pub unsafe fn new_unchecked( - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: OffsetsBuffer, values: Buffer, validity: Option, @@ -387,12 +382,12 @@ impl Utf8Array { "validity mask length must match the number of values" ); debug_assert!( - data_type.to_physical_type() == Self::default_data_type().to_physical_type(), + dtype.to_physical_type() == Self::default_dtype().to_physical_type(), "Utf8Array can only be initialized with DataType::Utf8 or DataType::LargeUtf8" ); Self { - data_type, + dtype, offsets, values, validity, @@ -404,17 +399,17 @@ impl Utf8Array { /// This function panics iff: /// * The last offset is not equal to the values' length. /// * the validity's length is not equal to `offsets.len()`. - /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. /// * The `values` between two consecutive `offsets` are not valid utf8 /// # Implementation /// This function is `O(N)` - checking utf8 is `O(N)` pub fn new( - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: OffsetsBuffer, values: Buffer, validity: Option, ) -> Self { - Self::try_new(data_type, offsets, values, validity).unwrap() + Self::try_new(dtype, offsets, values, validity).unwrap() } /// Returns a (non-null) [`Utf8Array`] created from a [`TrustedLen`] of `&str`. 
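A minimal sketch of how the renamed `Utf8Array` API reads after this change. It assumes the crate is consumed externally as `polars_arrow` with the `array` and `datatypes` modules public; the string values are illustrative, while the method names (`dtype`, `default_dtype`, `new_null`) come from the hunks above.

```rust
use polars_arrow::array::{Array, Utf8Array};
use polars_arrow::datatypes::ArrowDataType;

fn main() {
    // The accessor is now `dtype()` instead of `data_type()`.
    let arr = Utf8Array::<i64>::from_slice(["hello", "world"]);
    assert_eq!(arr.dtype(), &ArrowDataType::LargeUtf8);

    // Constructors keep the same shape; only the parameter name changed to `dtype`.
    let nulls = Utf8Array::<i64>::new_null(ArrowDataType::LargeUtf8, 3);
    assert_eq!(nulls.null_count(), 3);

    // `default_data_type` is renamed to `default_dtype` alongside the field rename.
    assert_eq!(Utf8Array::<i64>::default_dtype(), ArrowDataType::LargeUtf8);
}
```

The rename is mechanical: signatures, invariants, and error behaviour of `try_new`/`new_unchecked` are unchanged.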
@@ -497,7 +492,7 @@ impl Utf8Array { pub fn to_binary(&self) -> BinaryArray { unsafe { BinaryArray::new_unchecked( - BinaryArray::::default_data_type(), + BinaryArray::::default_dtype(), self.offsets.clone(), self.values.clone(), self.validity.clone(), @@ -518,13 +513,13 @@ impl Splitable for Utf8Array { ( Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), offsets: lhs_offsets, values: self.values.clone(), validity: lhs_validity, }, Self { - data_type: self.data_type.clone(), + dtype: self.dtype.clone(), offsets: rhs_offsets, values: self.values.clone(), validity: rhs_validity, @@ -560,11 +555,11 @@ unsafe impl GenericBinaryArray for Utf8Array { impl Default for Utf8Array { fn default() -> Self { - let data_type = if O::IS_LARGE { + let dtype = if O::IS_LARGE { ArrowDataType::LargeUtf8 } else { ArrowDataType::Utf8 }; - Utf8Array::new(data_type, Default::default(), Default::default(), None) + Utf8Array::new(dtype, Default::default(), Default::default(), None) } } diff --git a/crates/polars-arrow/src/array/utf8/mutable.rs b/crates/polars-arrow/src/array/utf8/mutable.rs index ef9a5e8527b7..570e795542ff 100644 --- a/crates/polars-arrow/src/array/utf8/mutable.rs +++ b/crates/polars-arrow/src/array/utf8/mutable.rs @@ -51,17 +51,17 @@ impl MutableUtf8Array { /// This function returns an error iff: /// * The last offset is not equal to the values' length. /// * the validity's length is not equal to `offsets.len()`. - /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. /// * The `values` between two consecutive `offsets` are not valid utf8 /// # Implementation /// This function is `O(N)` - checking utf8 is `O(N)` pub fn try_new( - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: Offsets, values: Vec, validity: Option, ) -> PolarsResult { - let values = MutableUtf8ValuesArray::try_new(data_type, offsets, values)?; + let values = MutableUtf8ValuesArray::try_new(dtype, offsets, values)?; if validity .as_ref() @@ -82,12 +82,12 @@ impl MutableUtf8Array { /// * The `offsets` and `values` are inconsistent /// * The validity is not `None` and its length is different from `offsets`'s length minus one. pub unsafe fn new_unchecked( - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: Offsets, values: Vec, validity: Option, ) -> Self { - let values = MutableUtf8ValuesArray::new_unchecked(data_type, offsets, values); + let values = MutableUtf8ValuesArray::new_unchecked(dtype, offsets, values); if let Some(ref validity) = validity { assert_eq!(values.len(), validity.len()); } @@ -100,8 +100,8 @@ impl MutableUtf8Array { Self::from_trusted_len_iter(slice.as_ref().iter().map(|x| x.as_ref())) } - fn default_data_type() -> ArrowDataType { - Utf8Array::::default_data_type() + fn default_dtype() -> ArrowDataType { + Utf8Array::::default_dtype() } /// Initializes a new [`MutableUtf8Array`] with a pre-allocated capacity of slots. @@ -198,8 +198,8 @@ impl MutableUtf8Array { /// Extract the low-end APIs from the [`MutableUtf8Array`]. 
pub fn into_data(self) -> (ArrowDataType, Offsets, Vec, Option) { - let (data_type, offsets, values) = self.values.into_inner(); - (data_type, offsets, values, self.validity) + let (dtype, offsets, values) = self.values.into_inner(); + (dtype, offsets, values, self.validity) } /// Returns an iterator of `&str` @@ -260,7 +260,7 @@ impl MutableArray for MutableUtf8Array { array.arced() } - fn data_type(&self) -> &ArrowDataType { + fn dtype(&self) -> &ArrowDataType { if O::IS_LARGE { &ArrowDataType::LargeUtf8 } else { @@ -391,7 +391,7 @@ impl MutableUtf8Array { let (validity, offsets, values) = trusted_len_unzip(iterator); // soundness: P is `str` - Self::new_unchecked(Self::default_data_type(), offsets, values, validity) + Self::new_unchecked(Self::default_dtype(), offsets, values, validity) } /// Creates a [`MutableUtf8Array`] from an iterator of trusted length. @@ -462,7 +462,7 @@ impl MutableUtf8Array { // soundness: P is `str` Ok(Self::new_unchecked( - Self::default_data_type(), + Self::default_dtype(), offsets, values, validity, @@ -522,9 +522,8 @@ impl> TryPush> for MutableUtf8Array { Some(value) => { self.values.try_push(value.as_ref())?; - match &mut self.validity { - Some(validity) => validity.push(true), - None => {}, + if let Some(validity) = &mut self.validity { + validity.push(true) } }, None => { diff --git a/crates/polars-arrow/src/array/utf8/mutable_values.rs b/crates/polars-arrow/src/array/utf8/mutable_values.rs index ce3c2f71f20c..ec362a40a8db 100644 --- a/crates/polars-arrow/src/array/utf8/mutable_values.rs +++ b/crates/polars-arrow/src/array/utf8/mutable_values.rs @@ -15,7 +15,7 @@ use crate::trusted_len::TrustedLen; /// from [`MutableUtf8Array`] in that it builds non-null [`Utf8Array`]. #[derive(Debug, Clone)] pub struct MutableUtf8ValuesArray { - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: Offsets, values: Vec, } @@ -27,7 +27,7 @@ impl From> for Utf8Array { // `Utf8Array` can be safely created from `MutableUtf8ValuesArray` without checks. unsafe { Utf8Array::::new_unchecked( - other.data_type, + other.dtype, other.offsets.into(), other.values.into(), None, @@ -41,7 +41,7 @@ impl From> for MutableUtf8Array { // SAFETY: // `MutableUtf8ValuesArray` has the same invariants as `MutableUtf8Array` unsafe { - MutableUtf8Array::::new_unchecked(other.data_type, other.offsets, other.values, None) + MutableUtf8Array::::new_unchecked(other.dtype, other.offsets, other.values, None) } } } @@ -56,7 +56,7 @@ impl MutableUtf8ValuesArray { /// Returns an empty [`MutableUtf8ValuesArray`]. pub fn new() -> Self { Self { - data_type: Self::default_data_type(), + dtype: Self::default_dtype(), offsets: Offsets::new(), values: Vec::::new(), } @@ -67,22 +67,22 @@ impl MutableUtf8ValuesArray { /// # Errors /// This function returns an error iff: /// * The last offset is not equal to the values' length. - /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. 
/// * The `values` between two consecutive `offsets` are not valid utf8 /// # Implementation /// This function is `O(N)` - checking utf8 is `O(N)` pub fn try_new( - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: Offsets, values: Vec, ) -> PolarsResult { try_check_utf8(&offsets, &values)?; - if data_type.to_physical_type() != Self::default_data_type().to_physical_type() { + if dtype.to_physical_type() != Self::default_dtype().to_physical_type() { polars_bail!(ComputeError: "MutableUtf8ValuesArray can only be initialized with DataType::Utf8 or DataType::LargeUtf8") } Ok(Self { - data_type, + dtype, offsets, values, }) @@ -93,7 +93,7 @@ impl MutableUtf8ValuesArray { /// # Panic /// This function does not panic iff: /// * The last offset is equal to the values' length. - /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is equal to either `Utf8` or `LargeUtf8`. + /// * The `dtype`'s [`crate::datatypes::PhysicalType`] is equal to either `Utf8` or `LargeUtf8`. /// /// # Safety /// This function is safe iff: @@ -102,19 +102,19 @@ impl MutableUtf8ValuesArray { /// # Implementation /// This function is `O(1)` pub unsafe fn new_unchecked( - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: Offsets, values: Vec, ) -> Self { try_check_offsets_bounds(&offsets, values.len()) .expect("The length of the values must be equal to the last offset value"); - if data_type.to_physical_type() != Self::default_data_type().to_physical_type() { + if dtype.to_physical_type() != Self::default_dtype().to_physical_type() { panic!("MutableUtf8ValuesArray can only be initialized with DataType::Utf8 or DataType::LargeUtf8") } Self { - data_type, + dtype, offsets, values, } @@ -122,8 +122,8 @@ impl MutableUtf8ValuesArray { /// Returns the default [`ArrowDataType`] of this container: [`ArrowDataType::Utf8`] or [`ArrowDataType::LargeUtf8`] /// depending on the generic [`Offset`]. - pub fn default_data_type() -> ArrowDataType { - Utf8Array::::default_data_type() + pub fn default_dtype() -> ArrowDataType { + Utf8Array::::default_dtype() } /// Initializes a new [`MutableUtf8ValuesArray`] with a pre-allocated capacity of items. @@ -134,7 +134,7 @@ impl MutableUtf8ValuesArray { /// Initializes a new [`MutableUtf8ValuesArray`] with a pre-allocated capacity of items and values. pub fn with_capacities(capacity: usize, values: usize) -> Self { Self { - data_type: Self::default_data_type(), + dtype: Self::default_dtype(), offsets: Offsets::::with_capacity(capacity), values: Vec::::with_capacity(values), } @@ -229,7 +229,7 @@ impl MutableUtf8ValuesArray { /// Extract the low-end APIs from the [`MutableUtf8ValuesArray`]. pub fn into_inner(self) -> (ArrowDataType, Offsets, Vec) { - (self.data_type, self.offsets, self.values) + (self.dtype, self.offsets, self.values) } } @@ -252,8 +252,8 @@ impl MutableArray for MutableUtf8ValuesArray { array.arced() } - fn data_type(&self) -> &ArrowDataType { - &self.data_type + fn dtype(&self) -> &ArrowDataType { + &self.dtype } fn as_any(&self) -> &dyn std::any::Any { @@ -282,7 +282,7 @@ impl> FromIterator
for MutableUtf8ValuesArray { fn from_iter>(iter: I) -> Self { let (offsets, values) = values_iter(iter.into_iter().map(StrAsBytes)); // soundness: T: AsRef and offsets are monotonically increasing - unsafe { Self::new_unchecked(Self::default_data_type(), offsets, values) } + unsafe { Self::new_unchecked(Self::default_dtype(), offsets, values) } } } @@ -349,7 +349,7 @@ impl MutableUtf8ValuesArray { let (offsets, values) = trusted_len_values_iter(iterator); // soundness: P is `str` and offsets are monotonically increasing - Self::new_unchecked(Self::default_data_type(), offsets, values) + Self::new_unchecked(Self::default_dtype(), offsets, values) } /// Returns a new [`MutableUtf8ValuesArray`] from an iterator. diff --git a/crates/polars-arrow/src/array/values.rs b/crates/polars-arrow/src/array/values.rs index 9864e4f4c129..197d97f167eb 100644 --- a/crates/polars-arrow/src/array/values.rs +++ b/crates/polars-arrow/src/array/values.rs @@ -54,7 +54,7 @@ impl ValueSize for BinaryArray { impl ValueSize for ArrayRef { fn get_values_size(&self) -> usize { - match self.data_type() { + match self.dtype() { ArrowDataType::LargeUtf8 => self .as_any() .downcast_ref::>() diff --git a/crates/polars-arrow/src/bitmap/bitmap_ops.rs b/crates/polars-arrow/src/bitmap/bitmap_ops.rs index 9e5ac502e6b5..a3edb658be4e 100644 --- a/crates/polars-arrow/src/bitmap/bitmap_ops.rs +++ b/crates/polars-arrow/src/bitmap/bitmap_ops.rs @@ -300,6 +300,22 @@ pub fn intersects_with_mut(lhs: &MutableBitmap, rhs: &MutableBitmap) -> bool { ) } +pub fn num_edges(lhs: &Bitmap) -> usize { + if lhs.is_empty() { + return 0; + } + + // @TODO: If is probably quite inefficient to do it like this because now either one is not + // aligned. Maybe, we can implement a smarter way to do this. + binary_fold( + &unsafe { lhs.clone().sliced_unchecked(0, lhs.len() - 1) }, + &unsafe { lhs.clone().sliced_unchecked(1, lhs.len() - 1) }, + |l, r| (l ^ r).count_ones() as usize, + 0, + |acc, v| acc + v, + ) +} + /// Compute `out[i] = if selector[i] { truthy[i] } else { falsy }`. pub fn select_constant(selector: &Bitmap, truthy: &Bitmap, falsy: bool) -> Bitmap { let falsy_mask: u64 = if falsy { diff --git a/crates/polars-arrow/src/bitmap/immutable.rs b/crates/polars-arrow/src/bitmap/immutable.rs index 4b52045afa9f..6ad76a07b639 100644 --- a/crates/polars-arrow/src/bitmap/immutable.rs +++ b/crates/polars-arrow/src/bitmap/immutable.rs @@ -555,6 +555,11 @@ impl Bitmap { pub fn select_constant(&self, truthy: &Self, falsy: bool) -> Self { super::bitmap_ops::select_constant(self, truthy, falsy) } + + /// Calculates the number of edges from `0 -> 1` and `1 -> 0`. + pub fn num_edges(&self) -> usize { + super::bitmap_ops::num_edges(self) + } } impl> From
for Bitmap { diff --git a/crates/polars-arrow/src/bitmap/utils/slice_iterator.rs b/crates/polars-arrow/src/bitmap/utils/slice_iterator.rs index dc388f1d41b5..f3083ad0b141 100644 --- a/crates/polars-arrow/src/bitmap/utils/slice_iterator.rs +++ b/crates/polars-arrow/src/bitmap/utils/slice_iterator.rs @@ -9,7 +9,8 @@ enum State { Finished, } -/// Iterator over a bitmap that returns slices of set regions +/// Iterator over a bitmap that returns slices of set regions. +/// /// This is the most efficient method to extract slices of values from arrays /// with a validity bitmap. /// For example, the bitmap `00101111` returns `[(0,4), (6,1)]` diff --git a/crates/polars-arrow/src/compute/aggregate/memory.rs b/crates/polars-arrow/src/compute/aggregate/memory.rs index 8b59503b93e7..bd4ba7ab6384 100644 --- a/crates/polars-arrow/src/compute/aggregate/memory.rs +++ b/crates/polars-arrow/src/compute/aggregate/memory.rs @@ -42,7 +42,7 @@ fn binview_size(array: &BinaryViewArrayGeneric) -> usiz /// FFI buffers are included in this estimation. pub fn estimated_bytes_size(array: &dyn Array) -> usize { use PhysicalType::*; - match array.data_type().to_physical_type() { + match array.dtype().to_physical_type() { Null => 0, Boolean => { let array = array.as_any().downcast_ref::().unwrap(); diff --git a/crates/polars-arrow/src/compute/aggregate/sum.rs b/crates/polars-arrow/src/compute/aggregate/sum.rs index 8ba9714f9521..9fbed5f8b1b6 100644 --- a/crates/polars-arrow/src/compute/aggregate/sum.rs +++ b/crates/polars-arrow/src/compute/aggregate/sum.rs @@ -102,9 +102,9 @@ where } } -/// Whether [`sum`] supports `data_type` -pub fn can_sum(data_type: &ArrowDataType) -> bool { - if let PhysicalType::Primitive(primitive) = data_type.to_physical_type() { +/// Whether [`sum`] supports `dtype` +pub fn can_sum(dtype: &ArrowDataType) -> bool { + if let PhysicalType::Primitive(primitive) = dtype.to_physical_type() { use PrimitiveType::*; matches!( primitive, @@ -120,11 +120,11 @@ pub fn can_sum(data_type: &ArrowDataType) -> bool { /// # Error /// Errors iff the operation is not supported. pub fn sum(array: &dyn Array) -> PolarsResult> { - Ok(match array.data_type().to_physical_type() { + Ok(match array.dtype().to_physical_type() { PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { - let data_type = array.data_type().clone(); + let dtype = array.dtype().clone(); let array = array.as_any().downcast_ref().unwrap(); - Box::new(PrimitiveScalar::new(data_type, sum_primitive::<$T>(array))) + Box::new(PrimitiveScalar::new(dtype, sum_primitive::<$T>(array))) }), _ => { unimplemented!() diff --git a/crates/polars-arrow/src/compute/arity.rs b/crates/polars-arrow/src/compute/arity.rs index e590e7b1974b..c5b397f6faac 100644 --- a/crates/polars-arrow/src/compute/arity.rs +++ b/crates/polars-arrow/src/compute/arity.rs @@ -8,21 +8,17 @@ use crate::bitmap::{Bitmap, MutableBitmap}; use crate::datatypes::ArrowDataType; use crate::types::NativeType; -/// Applies an unary and infallible function to a [`PrimitiveArray`]. This is the -/// fastest way to perform an operation on a [`PrimitiveArray`] when the benefits -/// of a vectorized operation outweighs the cost of branching nulls and -/// non-nulls. +/// Applies an unary and infallible function to a [`PrimitiveArray`]. +/// +/// This is the /// fastest way to perform an operation on a [`PrimitiveArray`] when the benefits +/// of a vectorized operation outweighs the cost of branching nulls and non-nulls. 
/// /// # Implementation /// This will apply the function for all values, including those on null slots. /// This implies that the operation must be infallible for any value of the /// corresponding type or this function may panic. #[inline] -pub fn unary( - array: &PrimitiveArray, - op: F, - data_type: ArrowDataType, -) -> PrimitiveArray +pub fn unary(array: &PrimitiveArray, op: F, dtype: ArrowDataType) -> PrimitiveArray where I: NativeType, O: NativeType, @@ -30,7 +26,7 @@ where { let values = array.values().iter().map(|v| op(*v)).collect::>(); - PrimitiveArray::::new(data_type, values.into(), array.validity().cloned()) + PrimitiveArray::::new(dtype, values.into(), array.validity().cloned()) } /// Version of unary that checks for errors in the closure used to create the @@ -38,7 +34,7 @@ where pub fn try_unary( array: &PrimitiveArray, op: F, - data_type: ArrowDataType, + dtype: ArrowDataType, ) -> PolarsResult> where I: NativeType, @@ -53,7 +49,7 @@ where .into(); Ok(PrimitiveArray::::new( - data_type, + dtype, values, array.validity().cloned(), )) @@ -64,7 +60,7 @@ where pub fn unary_with_bitmap( array: &PrimitiveArray, op: F, - data_type: ArrowDataType, + dtype: ArrowDataType, ) -> (PrimitiveArray, Bitmap) where I: NativeType, @@ -85,7 +81,7 @@ where .into(); ( - PrimitiveArray::::new(data_type, values, array.validity().cloned()), + PrimitiveArray::::new(dtype, values, array.validity().cloned()), mut_bitmap.into(), ) } @@ -96,7 +92,7 @@ where pub fn unary_checked( array: &PrimitiveArray, op: F, - data_type: ArrowDataType, + dtype: ArrowDataType, ) -> PrimitiveArray where I: NativeType, @@ -128,14 +124,17 @@ where let bitmap: Bitmap = mut_bitmap.into(); let validity = combine_validities_and(array.validity(), Some(&bitmap)); - PrimitiveArray::::new(data_type, values, validity) + PrimitiveArray::::new(dtype, values, validity) } -/// Applies a binary operations to two primitive arrays. This is the fastest -/// way to perform an operation on two primitive array when the benefits of a +/// Applies a binary operations to two primitive arrays. +/// +/// This is the fastest way to perform an operation on two primitive array when the benefits of a /// vectorized operation outweighs the cost of branching nulls and non-nulls. +/// /// # Errors /// This function errors iff the arrays have a different length. +/// /// # Implementation /// This will apply the function for all values, including those on null slots. /// This implies that the operation must be infallible for any value of the @@ -148,7 +147,7 @@ where pub fn binary( lhs: &PrimitiveArray, rhs: &PrimitiveArray, - data_type: ArrowDataType, + dtype: ArrowDataType, op: F, ) -> PrimitiveArray where @@ -168,7 +167,7 @@ where .collect::>() .into(); - PrimitiveArray::::new(data_type, values, validity) + PrimitiveArray::::new(dtype, values, validity) } /// Version of binary that checks for errors in the closure used to create the @@ -176,7 +175,7 @@ where pub fn try_binary( lhs: &PrimitiveArray, rhs: &PrimitiveArray, - data_type: ArrowDataType, + dtype: ArrowDataType, op: F, ) -> PolarsResult> where @@ -196,7 +195,7 @@ where .collect::>>()? .into(); - Ok(PrimitiveArray::::new(data_type, values, validity)) + Ok(PrimitiveArray::::new(dtype, values, validity)) } /// Version of binary that returns an array and bitmap. 
Used when working with @@ -204,7 +203,7 @@ where pub fn binary_with_bitmap( lhs: &PrimitiveArray, rhs: &PrimitiveArray, - data_type: ArrowDataType, + dtype: ArrowDataType, op: F, ) -> (PrimitiveArray, Bitmap) where @@ -231,7 +230,7 @@ where .into(); ( - PrimitiveArray::::new(data_type, values, validity), + PrimitiveArray::::new(dtype, values, validity), mut_bitmap.into(), ) } @@ -242,7 +241,7 @@ where pub fn binary_checked( lhs: &PrimitiveArray, rhs: &PrimitiveArray, - data_type: ArrowDataType, + dtype: ArrowDataType, op: F, ) -> PrimitiveArray where @@ -280,5 +279,5 @@ where // as Null let validity = combine_validities_and(validity.as_ref(), Some(&bitmap)); - PrimitiveArray::::new(data_type, values, validity) + PrimitiveArray::::new(dtype, values, validity) } diff --git a/crates/polars-arrow/src/compute/bitwise.rs b/crates/polars-arrow/src/compute/bitwise.rs index 37c26542b848..1762fb430e58 100644 --- a/crates/polars-arrow/src/compute/bitwise.rs +++ b/crates/polars-arrow/src/compute/bitwise.rs @@ -12,7 +12,7 @@ pub fn or(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray where T: NativeType + BitOr, { - binary(lhs, rhs, lhs.data_type().clone(), |a, b| a | b) + binary(lhs, rhs, lhs.dtype().clone(), |a, b| a | b) } /// Performs `XOR` operation between two [`PrimitiveArray`]s. @@ -22,7 +22,7 @@ pub fn xor(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArra where T: NativeType + BitXor, { - binary(lhs, rhs, lhs.data_type().clone(), |a, b| a ^ b) + binary(lhs, rhs, lhs.dtype().clone(), |a, b| a ^ b) } /// Performs `AND` operation on two [`PrimitiveArray`]s. @@ -32,7 +32,7 @@ pub fn and(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArra where T: NativeType + BitAnd, { - binary(lhs, rhs, lhs.data_type().clone(), |a, b| a & b) + binary(lhs, rhs, lhs.dtype().clone(), |a, b| a & b) } /// Returns a new [`PrimitiveArray`] with the bitwise `not`. @@ -41,7 +41,7 @@ where T: NativeType + Not, { let op = move |a: T| !a; - unary(array, op, array.data_type().clone()) + unary(array, op, array.dtype().clone()) } /// Performs `OR` operation between a [`PrimitiveArray`] and scalar. @@ -51,7 +51,7 @@ pub fn or_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray where T: NativeType + BitOr, { - unary(lhs, |a| a | *rhs, lhs.data_type().clone()) + unary(lhs, |a| a | *rhs, lhs.dtype().clone()) } /// Performs `XOR` operation between a [`PrimitiveArray`] and scalar. @@ -61,7 +61,7 @@ pub fn xor_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray where T: NativeType + BitXor, { - unary(lhs, |a| a ^ *rhs, lhs.data_type().clone()) + unary(lhs, |a| a ^ *rhs, lhs.dtype().clone()) } /// Performs `AND` operation between a [`PrimitiveArray`] and scalar. 
@@ -71,5 +71,5 @@ pub fn and_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray where T: NativeType + BitAnd, { - unary(lhs, |a| a & *rhs, lhs.data_type().clone()) + unary(lhs, |a| a & *rhs, lhs.dtype().clone()) } diff --git a/crates/polars-arrow/src/compute/cast/binary_to.rs b/crates/polars-arrow/src/compute/cast/binary_to.rs index 72c5f20922b0..e14a03040522 100644 --- a/crates/polars-arrow/src/compute/cast/binary_to.rs +++ b/crates/polars-arrow/src/compute/cast/binary_to.rs @@ -53,11 +53,11 @@ impl Parse for f64 { /// Conversion of binary pub fn binary_to_large_binary( from: &BinaryArray, - to_data_type: ArrowDataType, + to_dtype: ArrowDataType, ) -> BinaryArray { let values = from.values().clone(); BinaryArray::::new( - to_data_type, + to_dtype, from.offsets().into(), values, from.validity().cloned(), @@ -67,12 +67,12 @@ pub fn binary_to_large_binary( /// Conversion of binary pub fn binary_large_to_binary( from: &BinaryArray, - to_data_type: ArrowDataType, + to_dtype: ArrowDataType, ) -> PolarsResult> { let values = from.values().clone(); let offsets = from.offsets().try_into()?; Ok(BinaryArray::::new( - to_data_type, + to_dtype, offsets, values, from.validity().cloned(), @@ -82,10 +82,10 @@ pub fn binary_large_to_binary( /// Conversion to utf8 pub fn binary_to_utf8( from: &BinaryArray, - to_data_type: ArrowDataType, + to_dtype: ArrowDataType, ) -> PolarsResult> { Utf8Array::::try_new( - to_data_type, + to_dtype, from.offsets().clone(), from.values().clone(), from.validity().cloned(), @@ -97,12 +97,12 @@ pub fn binary_to_utf8( /// This function errors if the values are not valid utf8 pub fn binary_to_large_utf8( from: &BinaryArray, - to_data_type: ArrowDataType, + to_dtype: ArrowDataType, ) -> PolarsResult> { let values = from.values().clone(); let offsets = from.offsets().into(); - Utf8Array::::try_new(to_data_type, offsets, values, from.validity().cloned()) + Utf8Array::::try_new(to_dtype, offsets, values, from.validity().cloned()) } /// Casts a [`BinaryArray`] to a [`PrimitiveArray`], making any uncastable value a Null. @@ -169,16 +169,11 @@ fn fixed_size_to_offsets(values_len: usize, fixed_size: usize) -> Off /// Conversion of `FixedSizeBinary` to `Binary`. 
pub fn fixed_size_binary_binary( from: &FixedSizeBinaryArray, - to_data_type: ArrowDataType, + to_dtype: ArrowDataType, ) -> BinaryArray { let values = from.values().clone(); let offsets = fixed_size_to_offsets(values.len(), from.size()); - BinaryArray::::new( - to_data_type, - offsets.into(), - values, - from.validity().cloned(), - ) + BinaryArray::::new(to_dtype, offsets.into(), values, from.validity().cloned()) } pub fn fixed_size_binary_to_binview(from: &FixedSizeBinaryArray) -> BinaryViewArray { @@ -248,14 +243,11 @@ pub fn fixed_size_binary_to_binview(from: &FixedSizeBinaryArray) -> BinaryViewAr } /// Conversion of binary -pub fn binary_to_list( - from: &BinaryArray, - to_data_type: ArrowDataType, -) -> ListArray { +pub fn binary_to_list(from: &BinaryArray, to_dtype: ArrowDataType) -> ListArray { let values = from.values().clone(); let values = PrimitiveArray::new(ArrowDataType::UInt8, values, None); ListArray::::new( - to_data_type, + to_dtype, from.offsets().clone(), values.boxed(), from.validity().cloned(), diff --git a/crates/polars-arrow/src/compute/cast/binview_to.rs b/crates/polars-arrow/src/compute/cast/binview_to.rs index 5a12a14aaca7..406fdcc14e80 100644 --- a/crates/polars-arrow/src/compute/cast/binview_to.rs +++ b/crates/polars-arrow/src/compute/cast/binview_to.rs @@ -51,7 +51,7 @@ pub fn utf8view_to_utf8(array: &Utf8ViewArray) -> Utf8Array { let array = array.to_binview(); let out = view_to_binary::(&array); - let dtype = Utf8Array::::default_data_type(); + let dtype = Utf8Array::::default_dtype(); unsafe { Utf8Array::new_unchecked( dtype, diff --git a/crates/polars-arrow/src/compute/cast/decimal_to.rs b/crates/polars-arrow/src/compute/cast/decimal_to.rs index dd2f29e1a443..babd17158143 100644 --- a/crates/polars-arrow/src/compute/cast/decimal_to.rs +++ b/crates/polars-arrow/src/compute/cast/decimal_to.rs @@ -37,7 +37,7 @@ pub fn decimal_to_decimal( to_scale: usize, ) -> PrimitiveArray { let (from_precision, from_scale) = - if let ArrowDataType::Decimal(p, s) = from.data_type().to_logical_type() { + if let ArrowDataType::Decimal(p, s) = from.dtype().to_logical_type() { (*p, *s) } else { panic!("internal error: i128 is always a decimal") @@ -86,7 +86,7 @@ where T: NativeType + Float, f64: AsPrimitive, { - let (_, from_scale) = if let ArrowDataType::Decimal(p, s) = from.data_type().to_logical_type() { + let (_, from_scale) = if let ArrowDataType::Decimal(p, s) = from.dtype().to_logical_type() { (*p, *s) } else { panic!("internal error: i128 is always a decimal") @@ -116,7 +116,7 @@ pub fn decimal_to_integer(from: &PrimitiveArray) -> PrimitiveArray where T: NativeType + NumCast, { - let (_, from_scale) = if let ArrowDataType::Decimal(p, s) = from.data_type().to_logical_type() { + let (_, from_scale) = if let ArrowDataType::Decimal(p, s) = from.dtype().to_logical_type() { (*p, *s) } else { panic!("internal error: i128 is always a decimal") @@ -139,7 +139,7 @@ where /// Returns a [`Utf8Array`] where every element is the utf8 representation of the decimal. 
#[cfg(feature = "dtype-decimal")] pub(super) fn decimal_to_utf8view(from: &PrimitiveArray) -> Utf8ViewArray { - let (_, from_scale) = if let ArrowDataType::Decimal(p, s) = from.data_type().to_logical_type() { + let (_, from_scale) = if let ArrowDataType::Decimal(p, s) = from.dtype().to_logical_type() { (*p, *s) } else { panic!("internal error: i128 is always a decimal") diff --git a/crates/polars-arrow/src/compute/cast/dictionary_to.rs b/crates/polars-arrow/src/compute/cast/dictionary_to.rs index 134c9af7991f..d67a116ca0de 100644 --- a/crates/polars-arrow/src/compute/cast/dictionary_to.rs +++ b/crates/polars-arrow/src/compute/cast/dictionary_to.rs @@ -39,7 +39,7 @@ pub fn dictionary_to_dictionary_values( assert_eq!(values.len(), length); // this is guaranteed by `cast` unsafe { - DictionaryArray::try_new_unchecked(from.data_type().clone(), keys.clone(), values.clone()) + DictionaryArray::try_new_unchecked(from.dtype().clone(), keys.clone(), values.clone()) } } @@ -62,7 +62,7 @@ pub fn wrapping_dictionary_to_dictionary_values( )?; assert_eq!(values.len(), length); // this is guaranteed by `cast` unsafe { - DictionaryArray::try_new_unchecked(from.data_type().clone(), keys.clone(), values.clone()) + DictionaryArray::try_new_unchecked(from.dtype().clone(), keys.clone(), values.clone()) } } @@ -87,13 +87,10 @@ where if casted_keys.null_count() > keys.null_count() { polars_bail!(ComputeError: "overflow") } else { - let data_type = ArrowDataType::Dictionary( - K2::KEY_TYPE, - Box::new(values.data_type().clone()), - is_ordered, - ); + let dtype = + ArrowDataType::Dictionary(K2::KEY_TYPE, Box::new(values.dtype().clone()), is_ordered); // SAFETY: this is safe because given a type `T` that fits in a `usize`, casting it to type `P` either overflows or also fits in a `usize` - unsafe { DictionaryArray::try_new_unchecked(data_type, casted_keys, values.clone()) } + unsafe { DictionaryArray::try_new_unchecked(dtype, casted_keys, values.clone()) } } } @@ -114,13 +111,10 @@ where if casted_keys.null_count() > keys.null_count() { polars_bail!(ComputeError: "overflow") } else { - let data_type = ArrowDataType::Dictionary( - K2::KEY_TYPE, - Box::new(values.data_type().clone()), - is_ordered, - ); + let dtype = + ArrowDataType::Dictionary(K2::KEY_TYPE, Box::new(values.dtype().clone()), is_ordered); // some of the values may not fit in `usize` and thus this needs to be checked - DictionaryArray::try_new(data_type, casted_keys, values.clone()) + DictionaryArray::try_new(dtype, casted_keys, values.clone()) } } diff --git a/crates/polars-arrow/src/compute/cast/mod.rs b/crates/polars-arrow/src/compute/cast/mod.rs index 0afa67ec875a..9193abe8d476 100644 --- a/crates/polars-arrow/src/compute/cast/mod.rs +++ b/crates/polars-arrow/src/compute/cast/mod.rs @@ -89,7 +89,7 @@ fn cast_struct( let new_values = values .iter() .zip(fields) - .map(|(arr, field)| cast(arr.as_ref(), field.data_type(), options)) + .map(|(arr, field)| cast(arr.as_ref(), field.dtype(), options)) .collect::>>()?; Ok(StructArray::new( @@ -190,7 +190,7 @@ fn cast_list_to_fixed_size_list( list.offsets().first().to_usize(), list.offsets().range().to_usize(), ); - cast(sliced_values.as_ref(), inner.data_type(), options)? + cast(sliced_values.as_ref(), inner.dtype(), options)? }, } } else { @@ -230,7 +230,7 @@ fn cast_list_to_fixed_size_list( crate::compute::take::take_unchecked(list.values().as_ref(), &indices.freeze()) }; - cast(take_values.as_ref(), inner.data_type(), options)? + cast(take_values.as_ref(), inner.dtype(), options)? 
}; FixedSizeListArray::try_new( ArrowDataType::FixedSizeList(Box::new(inner.clone()), size), @@ -279,7 +279,7 @@ pub fn cast( options: CastOptionsImpl, ) -> PolarsResult> { use ArrowDataType::*; - let from_type = array.data_type(); + let from_type = array.dtype(); // clone array if types are the same if from_type == to_type { @@ -350,7 +350,7 @@ pub fn cast( Int64 => binview_to_primitive_dyn::(array, to_type, options), Float32 => binview_to_primitive_dyn::(array, to_type, options), Float64 => binview_to_primitive_dyn::(array, to_type, options), - LargeList(inner) if matches!(inner.data_type, ArrowDataType::UInt8) => { + LargeList(inner) if matches!(inner.dtype, ArrowDataType::UInt8) => { let bin_array = view_to_binary::(array.as_any().downcast_ref().unwrap()); Ok(binary_to_list(&bin_array, to_type.clone()).boxed()) }, @@ -371,7 +371,7 @@ pub fn cast( (_, List(to)) => { // cast primitive to list's primitive - let values = cast(array, &to.data_type, options)?; + let values = cast(array, &to.dtype, options)?; // create offsets, where if array.len() = 2, we have [0,1,2] let offsets = (0..=array.len() as i32).collect::>(); // SAFETY: offsets _are_ monotonically increasing @@ -384,7 +384,7 @@ pub fn cast( (_, LargeList(to)) if from_type != &LargeBinary => { // cast primitive to list's primitive - let values = cast(array, &to.data_type, options)?; + let values = cast(array, &to.dtype, options)?; // create offsets, where if array.len() = 2, we have [0,1,2] let offsets = (0..=array.len() as i64).collect::>(); // SAFETY: offsets _are_ monotonically increasing diff --git a/crates/polars-arrow/src/compute/cast/primitive_to.rs b/crates/polars-arrow/src/compute/cast/primitive_to.rs index 98df4b7f6b83..6fa9b9fb01fd 100644 --- a/crates/polars-arrow/src/compute/cast/primitive_to.rs +++ b/crates/polars-arrow/src/compute/cast/primitive_to.rs @@ -2,6 +2,7 @@ use std::hash::Hash; use num_traits::{AsPrimitive, Float, ToPrimitive}; use polars_error::PolarsResult; +use polars_utils::pl_str::PlSmallStr; use super::CastOptionsImpl; use crate::array::*; @@ -122,7 +123,7 @@ pub(super) fn primitive_to_utf8( let (values, offsets) = primitive_to_values_and_offsets(from); unsafe { Utf8Array::::new_unchecked( - Utf8Array::::default_data_type(), + Utf8Array::::default_dtype(), offsets.into(), values.into(), from.validity().cloned(), @@ -316,7 +317,7 @@ pub fn primitive_to_dictionary( ) -> PolarsResult> { let iter = from.iter().map(|x| x.copied()); let mut array = MutableDictionaryArray::::try_empty(MutablePrimitiveArray::::from( - from.data_type().clone(), + from.dtype().clone(), ))?; array.reserve(from.len()); array.try_extend(iter)?; @@ -324,16 +325,6 @@ pub fn primitive_to_dictionary( Ok(array.into()) } -/// Get the time unit as a multiple of a second -const fn time_unit_multiple(unit: TimeUnit) -> i64 { - match unit { - TimeUnit::Second => 1, - TimeUnit::Millisecond => MILLISECONDS, - TimeUnit::Microsecond => MICROSECONDS, - TimeUnit::Nanosecond => NANOSECONDS, - } -} - /// Conversion of dates pub fn date32_to_date64(from: &PrimitiveArray) -> PrimitiveArray { unary( @@ -444,7 +435,7 @@ pub fn timestamp_to_timestamp( from: &PrimitiveArray, from_unit: TimeUnit, to_unit: TimeUnit, - tz: &Option, + tz: &Option, ) -> PrimitiveArray { let from_size = time_unit_multiple(from_unit); let to_size = time_unit_multiple(to_unit); diff --git a/crates/polars-arrow/src/compute/cast/utf8_to.rs b/crates/polars-arrow/src/compute/cast/utf8_to.rs index 85b478c43817..180ee516f92e 100644 --- a/crates/polars-arrow/src/compute/cast/utf8_to.rs 
+++ b/crates/polars-arrow/src/compute/cast/utf8_to.rs @@ -35,35 +35,32 @@ pub fn utf8_to_dictionary( /// Conversion of utf8 pub fn utf8_to_large_utf8(from: &Utf8Array) -> Utf8Array { - let data_type = Utf8Array::::default_data_type(); + let dtype = Utf8Array::::default_dtype(); let validity = from.validity().cloned(); let values = from.values().clone(); let offsets = from.offsets().into(); // SAFETY: sound because `values` fulfills the same invariants as `from.values()` - unsafe { Utf8Array::::new_unchecked(data_type, offsets, values, validity) } + unsafe { Utf8Array::::new_unchecked(dtype, offsets, values, validity) } } /// Conversion of utf8 pub fn utf8_large_to_utf8(from: &Utf8Array) -> PolarsResult> { - let data_type = Utf8Array::::default_data_type(); + let dtype = Utf8Array::::default_dtype(); let validity = from.validity().cloned(); let values = from.values().clone(); let offsets = from.offsets().try_into()?; // SAFETY: sound because `values` fulfills the same invariants as `from.values()` - Ok(unsafe { Utf8Array::::new_unchecked(data_type, offsets, values, validity) }) + Ok(unsafe { Utf8Array::::new_unchecked(dtype, offsets, values, validity) }) } /// Conversion to binary -pub fn utf8_to_binary( - from: &Utf8Array, - to_data_type: ArrowDataType, -) -> BinaryArray { +pub fn utf8_to_binary(from: &Utf8Array, to_dtype: ArrowDataType) -> BinaryArray { // SAFETY: erasure of an invariant is always safe unsafe { BinaryArray::::new( - to_data_type, + to_dtype, from.offsets().clone(), from.values().clone(), from.validity().cloned(), diff --git a/crates/polars-arrow/src/compute/concatenate.rs b/crates/polars-arrow/src/compute/concatenate.rs index 17b28b65f682..1951cad9f4f5 100644 --- a/crates/polars-arrow/src/compute/concatenate.rs +++ b/crates/polars-arrow/src/compute/concatenate.rs @@ -1,18 +1,4 @@ //r Contains the concatenate kernel -//! -//! Example: -//! -//! ``` -//! use polars_arrow::array::Utf8Array; -//! use polars_arrow::compute::concatenate::concatenate; -//! -//! let arr = concatenate(&[ -//! &Utf8Array::::from_slice(["hello", "world"]), -//! &Utf8Array::::from_slice(["!"]), -//! ]).unwrap(); -//! assert_eq!(arr.len(), 3); -//! 
``` - use polars_error::{polars_bail, PolarsResult}; use crate::array::growable::make_growable; @@ -27,7 +13,7 @@ pub fn concatenate(arrays: &[&dyn Array]) -> PolarsResult> { if arrays .iter() - .any(|array| array.data_type() != arrays[0].data_type()) + .any(|array| array.dtype() != arrays[0].dtype()) { polars_bail!(InvalidOperation: "It is not possible to concatenate arrays of different data types.") } diff --git a/crates/polars-arrow/src/compute/take/binary.rs b/crates/polars-arrow/src/compute/take/binary.rs index 8d2b971ced8f..576fbc8e4f37 100644 --- a/crates/polars-arrow/src/compute/take/binary.rs +++ b/crates/polars-arrow/src/compute/take/binary.rs @@ -25,7 +25,7 @@ pub unsafe fn take_unchecked( values: &BinaryArray, indices: &PrimitiveArray, ) -> BinaryArray { - let data_type = values.data_type().clone(); + let dtype = values.dtype().clone(); let indices_has_validity = indices.null_count() > 0; let values_has_validity = values.null_count() > 0; @@ -37,5 +37,5 @@ pub unsafe fn take_unchecked( (false, true) => take_indices_validity(values.offsets(), values.values(), indices), (true, true) => take_values_indices_validity(values, indices), }; - BinaryArray::::new_unchecked(data_type, offsets, values, validity) + BinaryArray::::new_unchecked(dtype, offsets, values, validity) } diff --git a/crates/polars-arrow/src/compute/take/binview.rs b/crates/polars-arrow/src/compute/take/binview.rs index 65ff633a080a..02b0272be873 100644 --- a/crates/polars-arrow/src/compute/take/binview.rs +++ b/crates/polars-arrow/src/compute/take/binview.rs @@ -12,7 +12,7 @@ pub(super) unsafe fn take_binview_unchecked( take_values_and_validity_unchecked(arr.views(), arr.validity(), indices); BinaryViewArray::new_unchecked_unknown_md( - arr.data_type().clone(), + arr.dtype().clone(), views.into(), arr.data_buffers().clone(), validity, diff --git a/crates/polars-arrow/src/compute/take/boolean.rs b/crates/polars-arrow/src/compute/take/boolean.rs index 3e6008d54652..745c7036c16b 100644 --- a/crates/polars-arrow/src/compute/take/boolean.rs +++ b/crates/polars-arrow/src/compute/take/boolean.rs @@ -63,7 +63,7 @@ pub unsafe fn take_unchecked( values: &BooleanArray, indices: &PrimitiveArray, ) -> BooleanArray { - let data_type = values.data_type().clone(); + let dtype = values.dtype().clone(); let indices_has_validity = indices.null_count() > 0; let values_has_validity = values.null_count() > 0; @@ -74,5 +74,5 @@ pub unsafe fn take_unchecked( (true, true) => take_values_indices_validity(values, indices), }; - BooleanArray::new(data_type, values, validity) + BooleanArray::new(dtype, values, validity) } diff --git a/crates/polars-arrow/src/compute/take/list.rs b/crates/polars-arrow/src/compute/take/list.rs index 547b738e0acb..36ca1f72131f 100644 --- a/crates/polars-arrow/src/compute/take/list.rs +++ b/crates/polars-arrow/src/compute/take/list.rs @@ -28,7 +28,7 @@ pub(super) unsafe fn take_unchecked( ) -> ListArray { // fast-path: all values to take are none if indices.null_count() == indices.len() { - return ListArray::::new_null(values.data_type().clone(), indices.len()); + return ListArray::::new_null(values.dtype().clone(), indices.len()); } let mut capacity = 0; diff --git a/crates/polars-arrow/src/compute/take/mod.rs b/crates/polars-arrow/src/compute/take/mod.rs index 34b62802dc12..aed14823af1e 100644 --- a/crates/polars-arrow/src/compute/take/mod.rs +++ b/crates/polars-arrow/src/compute/take/mod.rs @@ -40,12 +40,12 @@ use crate::with_match_primitive_type_full; /// Doesn't do bound checks pub unsafe fn 
take_unchecked(values: &dyn Array, indices: &IdxArr) -> Box { if indices.len() == 0 { - return new_empty_array(values.data_type().clone()); + return new_empty_array(values.dtype().clone()); } use crate::datatypes::PhysicalType::*; - match values.data_type().to_physical_type() { - Null => Box::new(NullArray::new(values.data_type().clone(), indices.len())), + match values.dtype().to_physical_type() { + Null => Box::new(NullArray::new(values.dtype().clone(), indices.len())), Boolean => { let values = values.as_any().downcast_ref().unwrap(); Box::new(boolean::take_unchecked(values, indices)) diff --git a/crates/polars-arrow/src/compute/take/primitive.rs b/crates/polars-arrow/src/compute/take/primitive.rs index c8686201fdbb..8997323b5c15 100644 --- a/crates/polars-arrow/src/compute/take/primitive.rs +++ b/crates/polars-arrow/src/compute/take/primitive.rs @@ -76,5 +76,5 @@ pub unsafe fn take_primitive_unchecked( ) -> PrimitiveArray { let (values, validity) = take_values_and_validity_unchecked(arr.values(), arr.validity(), indices); - PrimitiveArray::new_unchecked(arr.data_type().clone(), values.into(), validity) + PrimitiveArray::new_unchecked(arr.dtype().clone(), values.into(), validity) } diff --git a/crates/polars-arrow/src/compute/take/structure.rs b/crates/polars-arrow/src/compute/take/structure.rs index 3619dae307bb..caad9f4ee0a4 100644 --- a/crates/polars-arrow/src/compute/take/structure.rs +++ b/crates/polars-arrow/src/compute/take/structure.rs @@ -30,5 +30,5 @@ pub(super) unsafe fn take_unchecked(array: &StructArray, indices: &IdxArr) -> St .validity() .map(|b| super::bitmap::take_bitmap_nulls_unchecked(b, indices)); let validity = combine_validities_and(validity.as_ref(), indices.validity()); - StructArray::new(array.data_type().clone(), values, validity) + StructArray::new(array.dtype().clone(), values, validity) } diff --git a/crates/polars-arrow/src/compute/temporal.rs b/crates/polars-arrow/src/compute/temporal.rs index 1198c04bb152..309493fbbbdb 100644 --- a/crates/polars-arrow/src/compute/temporal.rs +++ b/crates/polars-arrow/src/compute/temporal.rs @@ -51,20 +51,20 @@ impl Int8IsoWeek for chrono::DateTime {} // Macro to avoid repetition in functions, that apply // `chrono::Datelike` methods on Arrays macro_rules! date_like { - ($extract:ident, $array:ident, $data_type:path) => { - match $array.data_type().to_logical_type() { + ($extract:ident, $array:ident, $dtype:path) => { + match $array.dtype().to_logical_type() { ArrowDataType::Date32 | ArrowDataType::Date64 | ArrowDataType::Timestamp(_, None) => { - date_variants($array, $data_type, |x| x.$extract().try_into().unwrap()) + date_variants($array, $dtype, |x| x.$extract().try_into().unwrap()) }, ArrowDataType::Timestamp(time_unit, Some(timezone_str)) => { let array = $array.as_any().downcast_ref().unwrap(); - if let Ok(timezone) = parse_offset(timezone_str) { + if let Ok(timezone) = parse_offset(timezone_str.as_str()) { Ok(extract_impl(array, *time_unit, timezone, |x| { x.$extract().try_into().unwrap() })) } else { - chrono_tz(array, *time_unit, timezone_str, |x| { + chrono_tz(array, *time_unit, timezone_str.as_str(), |x| { x.$extract().try_into().unwrap() }) } @@ -75,12 +75,14 @@ macro_rules! date_like { } /// Extracts the years of a temporal array as [`PrimitiveArray`]. +/// /// Use [`can_year`] to check if this operation is supported for the target [`ArrowDataType`]. 
pub fn year(array: &dyn Array) -> PolarsResult> { date_like!(year, array, ArrowDataType::Int32) } /// Extracts the months of a temporal array as [`PrimitiveArray`]. +/// /// Value ranges from 1 to 12. /// Use [`can_month`] to check if this operation is supported for the target [`ArrowDataType`]. pub fn month(array: &dyn Array) -> PolarsResult> { @@ -88,6 +90,7 @@ pub fn month(array: &dyn Array) -> PolarsResult> { } /// Extracts the days of a temporal array as [`PrimitiveArray`]. +/// /// Value ranges from 1 to 32 (Last day depends on month). /// Use [`can_day`] to check if this operation is supported for the target [`ArrowDataType`]. pub fn day(array: &dyn Array) -> PolarsResult> { @@ -95,13 +98,15 @@ pub fn day(array: &dyn Array) -> PolarsResult> { } /// Extracts weekday of a temporal array as [`PrimitiveArray`]. +/// /// Monday is 1, Tuesday is 2, ..., Sunday is 7. /// Use [`can_weekday`] to check if this operation is supported for the target [`ArrowDataType`] pub fn weekday(array: &dyn Array) -> PolarsResult> { date_like!(i8_weekday, array, ArrowDataType::Int8) } -/// Extracts ISO week of a temporal array as [`PrimitiveArray`] +/// Extracts ISO week of a temporal array as [`PrimitiveArray`]. +/// /// Value ranges from 1 to 53 (Last week depends on the year). /// Use [`can_iso_week`] to check if this operation is supported for the target [`ArrowDataType`] pub fn iso_week(array: &dyn Array) -> PolarsResult> { @@ -111,10 +116,10 @@ pub fn iso_week(array: &dyn Array) -> PolarsResult> { // Macro to avoid repetition in functions, that apply // `chrono::Timelike` methods on Arrays macro_rules! time_like { - ($extract:ident, $array:ident, $data_type:path) => { - match $array.data_type().to_logical_type() { + ($extract:ident, $array:ident, $dtype:path) => { + match $array.dtype().to_logical_type() { ArrowDataType::Date32 | ArrowDataType::Date64 | ArrowDataType::Timestamp(_, None) => { - date_variants($array, $data_type, |x| x.$extract().try_into().unwrap()) + date_variants($array, $dtype, |x| x.$extract().try_into().unwrap()) }, ArrowDataType::Time32(_) | ArrowDataType::Time64(_) => { time_variants($array, ArrowDataType::UInt32, |x| { @@ -124,12 +129,12 @@ macro_rules! time_like { ArrowDataType::Timestamp(time_unit, Some(timezone_str)) => { let array = $array.as_any().downcast_ref().unwrap(); - if let Ok(timezone) = parse_offset(timezone_str) { + if let Ok(timezone) = parse_offset(timezone_str.as_str()) { Ok(extract_impl(array, *time_unit, timezone, |x| { x.$extract().try_into().unwrap() })) } else { - chrono_tz(array, *time_unit, timezone_str, |x| { + chrono_tz(array, *time_unit, timezone_str.as_str(), |x| { x.$extract().try_into().unwrap() }) } @@ -161,6 +166,7 @@ pub fn second(array: &dyn Array) -> PolarsResult> { } /// Extracts the nanoseconds of a temporal array as [`PrimitiveArray`]. +/// /// Value ranges from 0 to 1_999_999_999. /// The range from 1_000_000_000 to 1_999_999_999 represents the leap second. /// Use [`can_nanosecond`] to check if this operation is supported for the target [`ArrowDataType`]. 
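A small usage sketch for the temporal extraction kernels touched here; the `can_*` helpers now take `dtype: &ArrowDataType`. It assumes the `compute::temporal` module (and any required compute feature of `polars-arrow`) is enabled; the date value and paths are illustrative.

```rust
use polars_arrow::array::PrimitiveArray;
use polars_arrow::compute::temporal::{can_year, year};
use polars_arrow::datatypes::ArrowDataType;

fn main() {
    // Date32 stores days since the UNIX epoch; 18628 days is 2021-01-01.
    let dates = PrimitiveArray::<i32>::from_slice([18628]).to(ArrowDataType::Date32);

    // `can_year` takes the dtype by reference, mirroring the renamed parameter.
    assert!(can_year(dates.dtype()));

    let years = year(&dates).unwrap();
    assert_eq!(years.value(0), 2021);
}
```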
@@ -170,27 +176,27 @@ pub fn nanosecond(array: &dyn Array) -> PolarsResult> { fn date_variants( array: &dyn Array, - data_type: ArrowDataType, + dtype: ArrowDataType, op: F, ) -> PolarsResult> where O: NativeType, F: Fn(chrono::NaiveDateTime) -> O, { - match array.data_type().to_logical_type() { + match array.dtype().to_logical_type() { ArrowDataType::Date32 => { let array = array .as_any() .downcast_ref::>() .unwrap(); - Ok(unary(array, |x| op(date32_to_datetime(x)), data_type)) + Ok(unary(array, |x| op(date32_to_datetime(x)), dtype)) }, ArrowDataType::Date64 => { let array = array .as_any() .downcast_ref::>() .unwrap(); - Ok(unary(array, |x| op(date64_to_datetime(x)), data_type)) + Ok(unary(array, |x| op(date64_to_datetime(x)), dtype)) }, ArrowDataType::Timestamp(time_unit, None) => { let array = array @@ -213,41 +219,41 @@ where fn time_variants( array: &dyn Array, - data_type: ArrowDataType, + dtype: ArrowDataType, op: F, ) -> PolarsResult> where O: NativeType, F: Fn(chrono::NaiveTime) -> O, { - match array.data_type().to_logical_type() { + match array.dtype().to_logical_type() { ArrowDataType::Time32(TimeUnit::Second) => { let array = array .as_any() .downcast_ref::>() .unwrap(); - Ok(unary(array, |x| op(time32s_to_time(x)), data_type)) + Ok(unary(array, |x| op(time32s_to_time(x)), dtype)) }, ArrowDataType::Time32(TimeUnit::Millisecond) => { let array = array .as_any() .downcast_ref::>() .unwrap(); - Ok(unary(array, |x| op(time32ms_to_time(x)), data_type)) + Ok(unary(array, |x| op(time32ms_to_time(x)), dtype)) }, ArrowDataType::Time64(TimeUnit::Microsecond) => { let array = array .as_any() .downcast_ref::>() .unwrap(); - Ok(unary(array, |x| op(time64us_to_time(x)), data_type)) + Ok(unary(array, |x| op(time64us_to_time(x)), dtype)) }, ArrowDataType::Time64(TimeUnit::Nanosecond) => { let array = array .as_any() .downcast_ref::>() .unwrap(); - Ok(unary(array, |x| op(time64ns_to_time(x)), data_type)) + Ok(unary(array, |x| op(time64ns_to_time(x)), dtype)) }, _ => unreachable!(), } @@ -350,33 +356,33 @@ where /// assert_eq!(can_year(&ArrowDataType::Date32), true); /// assert_eq!(can_year(&ArrowDataType::Int8), false); /// ``` -pub fn can_year(data_type: &ArrowDataType) -> bool { - can_date(data_type) +pub fn can_year(dtype: &ArrowDataType) -> bool { + can_date(dtype) } /// Checks if an array of type `datatype` can perform month operation -pub fn can_month(data_type: &ArrowDataType) -> bool { - can_date(data_type) +pub fn can_month(dtype: &ArrowDataType) -> bool { + can_date(dtype) } /// Checks if an array of type `datatype` can perform day operation -pub fn can_day(data_type: &ArrowDataType) -> bool { - can_date(data_type) +pub fn can_day(dtype: &ArrowDataType) -> bool { + can_date(dtype) } -/// Checks if an array of type `data_type` can perform weekday operation -pub fn can_weekday(data_type: &ArrowDataType) -> bool { - can_date(data_type) +/// Checks if an array of type `dtype` can perform weekday operation +pub fn can_weekday(dtype: &ArrowDataType) -> bool { + can_date(dtype) } -/// Checks if an array of type `data_type` can perform ISO week operation -pub fn can_iso_week(data_type: &ArrowDataType) -> bool { - can_date(data_type) +/// Checks if an array of type `dtype` can perform ISO week operation +pub fn can_iso_week(dtype: &ArrowDataType) -> bool { + can_date(dtype) } -fn can_date(data_type: &ArrowDataType) -> bool { +fn can_date(dtype: &ArrowDataType) -> bool { matches!( - data_type, + dtype, ArrowDataType::Date32 | ArrowDataType::Date64 | ArrowDataType::Timestamp(_, _) ) } @@ 
-391,28 +397,28 @@ fn can_date(data_type: &ArrowDataType) -> bool { /// assert_eq!(can_hour(&ArrowDataType::Time32(TimeUnit::Second)), true); /// assert_eq!(can_hour(&ArrowDataType::Int8), false); /// ``` -pub fn can_hour(data_type: &ArrowDataType) -> bool { - can_time(data_type) +pub fn can_hour(dtype: &ArrowDataType) -> bool { + can_time(dtype) } /// Checks if an array of type `datatype` can perform minute operation -pub fn can_minute(data_type: &ArrowDataType) -> bool { - can_time(data_type) +pub fn can_minute(dtype: &ArrowDataType) -> bool { + can_time(dtype) } /// Checks if an array of type `datatype` can perform second operation -pub fn can_second(data_type: &ArrowDataType) -> bool { - can_time(data_type) +pub fn can_second(dtype: &ArrowDataType) -> bool { + can_time(dtype) } /// Checks if an array of type `datatype` can perform nanosecond operation -pub fn can_nanosecond(data_type: &ArrowDataType) -> bool { - can_time(data_type) +pub fn can_nanosecond(dtype: &ArrowDataType) -> bool { + can_time(dtype) } -fn can_time(data_type: &ArrowDataType) -> bool { +fn can_time(dtype: &ArrowDataType) -> bool { matches!( - data_type, + dtype, ArrowDataType::Time32(TimeUnit::Second) | ArrowDataType::Time32(TimeUnit::Millisecond) | ArrowDataType::Time64(TimeUnit::Microsecond) diff --git a/crates/polars-arrow/src/datatypes/field.rs b/crates/polars-arrow/src/datatypes/field.rs index 950f081017c4..8bf18af82f46 100644 --- a/crates/polars-arrow/src/datatypes/field.rs +++ b/crates/polars-arrow/src/datatypes/field.rs @@ -1,3 +1,4 @@ +use polars_utils::pl_str::PlSmallStr; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -15,21 +16,28 @@ use super::{ArrowDataType, Metadata}; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Field { /// Its name - pub name: String, + pub name: PlSmallStr, /// Its logical [`ArrowDataType`] - pub data_type: ArrowDataType, + pub dtype: ArrowDataType, /// Its nullability pub is_nullable: bool, /// Additional custom (opaque) metadata. pub metadata: Metadata, } +/// Support for `ArrowSchema::from_iter([field, ..])` +impl From for (PlSmallStr, Field) { + fn from(value: Field) -> Self { + (value.name.clone(), value) + } +} + impl Field { /// Creates a new [`Field`]. - pub fn new>(name: T, data_type: ArrowDataType, is_nullable: bool) -> Self { + pub fn new(name: PlSmallStr, dtype: ArrowDataType, is_nullable: bool) -> Self { Field { - name: name.into(), - data_type, + name, + dtype, is_nullable, metadata: Default::default(), } @@ -40,7 +48,7 @@ impl Field { pub fn with_metadata(self, metadata: Metadata) -> Self { Self { name: self.name, - data_type: self.data_type, + dtype: self.dtype, is_nullable: self.is_nullable, metadata, } @@ -48,16 +56,26 @@ impl Field { /// Returns the [`Field`]'s [`ArrowDataType`]. 
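// `Field::new` now takes `PlSmallStr` names, and the `From<Field>` impl above yields the
// `(name, Field)` pair that `ArrowSchema::from_iter([field, ..])` expects. A sketch of the
// resulting construction style (field names and dtypes are illustrative only):
use polars_arrow::datatypes::{ArrowDataType, ArrowSchema, Field};
use polars_utils::pl_str::PlSmallStr;

fn example_schema() -> ArrowSchema {
    let a = Field::new(PlSmallStr::from_static("a"), ArrowDataType::Int64, true);
    let b = Field::new(PlSmallStr::from_static("b"), ArrowDataType::Utf8View, false);
    // Each `Field` converts into a `(name, Field)` entry of the name-keyed schema.
    ArrowSchema::from_iter([a, b])
}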
#[inline] - pub fn data_type(&self) -> &ArrowDataType { - &self.data_type + pub fn dtype(&self) -> &ArrowDataType { + &self.dtype } } #[cfg(feature = "arrow_rs")] impl From for arrow_schema::Field { fn from(value: Field) -> Self { - Self::new(value.name, value.data_type.into(), value.is_nullable) - .with_metadata(value.metadata.into_iter().collect()) + Self::new( + value.name.to_string(), + value.dtype.into(), + value.is_nullable, + ) + .with_metadata( + value + .metadata + .into_iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(), + ) } } @@ -71,13 +89,18 @@ impl From for Field { #[cfg(feature = "arrow_rs")] impl From<&arrow_schema::Field> for Field { fn from(value: &arrow_schema::Field) -> Self { - let data_type = value.data_type().clone().into(); + let dtype = value.data_type().clone().into(); let metadata = value .metadata() .iter() - .map(|(k, v)| (k.clone(), v.clone())) + .map(|(k, v)| (PlSmallStr::from_str(k), PlSmallStr::from_str(v))) .collect(); - Self::new(value.name(), data_type, value.is_nullable()).with_metadata(metadata) + Self::new( + PlSmallStr::from_str(value.name().as_str()), + dtype, + value.is_nullable(), + ) + .with_metadata(metadata) } } diff --git a/crates/polars-arrow/src/datatypes/mod.rs b/crates/polars-arrow/src/datatypes/mod.rs index c232c985d8a0..6ef9687f146e 100644 --- a/crates/polars-arrow/src/datatypes/mod.rs +++ b/crates/polars-arrow/src/datatypes/mod.rs @@ -9,14 +9,15 @@ use std::sync::Arc; pub use field::Field; pub use physical_type::*; +use polars_utils::pl_str::PlSmallStr; pub use schema::{ArrowSchema, ArrowSchemaRef}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -/// typedef for [BTreeMap] denoting [`Field`]'s and [`ArrowSchema`]'s metadata. -pub type Metadata = BTreeMap; -/// typedef for [Option<(String, Option)>] descr -pub(crate) type Extension = Option<(String, Option)>; +/// typedef for [BTreeMap] denoting [`Field`]'s and [`ArrowSchema`]'s metadata. +pub type Metadata = BTreeMap; +/// typedef for [Option<(PlSmallStr, Option)>] descr +pub(crate) type Extension = Option<(PlSmallStr, Option)>; /// The set of supported logical types in this crate. /// @@ -70,7 +71,7 @@ pub enum ArrowDataType { /// /// When the timezone is not specified, the timestamp is considered to have no timezone /// and is represented _as is_ - Timestamp(TimeUnit, Option), + Timestamp(TimeUnit, Option), /// An [`i32`] representing the elapsed time since UNIX epoch (1970-01-01) /// in days. Date32, @@ -163,7 +164,7 @@ pub enum ArrowDataType { /// - name /// - physical type /// - metadata - Extension(String, Box, Option), + Extension(PlSmallStr, Box, Option), /// A binary type that inlines small values /// and can intern bytes. 
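// With `Metadata` now keyed by `PlSmallStr` and timezones stored as `Option<PlSmallStr>`,
// a timezone-aware timestamp field with custom metadata looks roughly like this
// (the "source"/"event-log" metadata pair is purely hypothetical):
use polars_arrow::datatypes::{ArrowDataType, Field, Metadata, TimeUnit};
use polars_utils::pl_str::PlSmallStr;

fn timestamp_field() -> Field {
    let mut metadata = Metadata::new();
    metadata.insert(
        PlSmallStr::from_static("source"),
        PlSmallStr::from_static("event-log"),
    );
    Field::new(
        PlSmallStr::from_static("ts"),
        ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(PlSmallStr::from_static("UTC"))),
        true,
    )
    .with_metadata(metadata)
}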
BinaryView, @@ -193,7 +194,9 @@ impl From for arrow_schema::DataType { ArrowDataType::Float16 => Self::Float16, ArrowDataType::Float32 => Self::Float32, ArrowDataType::Float64 => Self::Float64, - ArrowDataType::Timestamp(unit, tz) => Self::Timestamp(unit.into(), tz.map(Into::into)), + ArrowDataType::Timestamp(unit, tz) => { + Self::Timestamp(unit.into(), tz.map(|x| Arc::::from(x.as_str()))) + }, ArrowDataType::Date32 => Self::Date32, ArrowDataType::Date64 => Self::Date64, ArrowDataType::Time32(unit) => Self::Time32(unit.into()), @@ -260,7 +263,7 @@ impl From for ArrowDataType { DataType::Float32 => Self::Float32, DataType::Float64 => Self::Float64, DataType::Timestamp(unit, tz) => { - Self::Timestamp(unit.into(), tz.map(|x| x.to_string())) + Self::Timestamp(unit.into(), tz.map(|x| PlSmallStr::from_str(x.as_ref()))) }, DataType::Date32 => Self::Date32, DataType::Date64 => Self::Date64, @@ -494,16 +497,16 @@ impl ArrowDataType { Interval(IntervalUnit::MonthDayNano) => unimplemented!(), Binary => Binary, List(field) => List(Box::new(Field { - data_type: field.data_type.underlying_physical_type(), + dtype: field.dtype.underlying_physical_type(), ..*field.clone() })), LargeList(field) => LargeList(Box::new(Field { - data_type: field.data_type.underlying_physical_type(), + dtype: field.dtype.underlying_physical_type(), ..*field.clone() })), FixedSizeList(field, width) => FixedSizeList( Box::new(Field { - data_type: field.data_type.underlying_physical_type(), + dtype: field.dtype.underlying_physical_type(), ..*field.clone() }), *width, @@ -512,7 +515,7 @@ impl ArrowDataType { fields .iter() .map(|field| Field { - data_type: field.data_type.underlying_physical_type(), + dtype: field.dtype.underlying_physical_type(), ..field.clone() }) .collect(), @@ -538,13 +541,29 @@ impl ArrowDataType { pub fn inner_dtype(&self) -> Option<&ArrowDataType> { match self { - ArrowDataType::List(inner) => Some(inner.data_type()), - ArrowDataType::LargeList(inner) => Some(inner.data_type()), - ArrowDataType::FixedSizeList(inner, _) => Some(inner.data_type()), + ArrowDataType::List(inner) => Some(inner.dtype()), + ArrowDataType::LargeList(inner) => Some(inner.dtype()), + ArrowDataType::FixedSizeList(inner, _) => Some(inner.dtype()), _ => None, } } + pub fn is_nested(&self) -> bool { + use ArrowDataType as D; + + matches!( + self, + D::List(_) + | D::LargeList(_) + | D::FixedSizeList(_, _) + | D::Struct(_) + | D::Union(_, _, _) + | D::Map(_, _) + | D::Dictionary(_, _, _) + | D::Extension(_, _, _) + ) + } + pub fn is_view(&self) -> bool { matches!(self, ArrowDataType::Utf8View | ArrowDataType::BinaryView) } @@ -593,8 +612,10 @@ pub type SchemaRef = Arc; /// support get extension for metadata pub fn get_extension(metadata: &Metadata) -> Extension { - if let Some(name) = metadata.get("ARROW:extension:name") { - let metadata = metadata.get("ARROW:extension:metadata").cloned(); + if let Some(name) = metadata.get(&PlSmallStr::from_static("ARROW:extension:name")) { + let metadata = metadata + .get(&PlSmallStr::from_static("ARROW:extension:metadata")) + .cloned(); Some((name.clone(), metadata)) } else { None diff --git a/crates/polars-arrow/src/datatypes/physical_type.rs b/crates/polars-arrow/src/datatypes/physical_type.rs index 31693cefd4bd..174c0401ca3f 100644 --- a/crates/polars-arrow/src/datatypes/physical_type.rs +++ b/crates/polars-arrow/src/datatypes/physical_type.rs @@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize}; pub use crate::types::PrimitiveType; /// The set of physical types: unique in-memory representations of 
an Arrow array. +/// /// A physical type has a one-to-many relationship with a [`crate::datatypes::ArrowDataType`] and /// a one-to-one mapping to each struct in this crate that implements [`crate::array::Array`]. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] diff --git a/crates/polars-arrow/src/datatypes/schema.rs b/crates/polars-arrow/src/datatypes/schema.rs index 9b01816c1135..02920204b4dc 100644 --- a/crates/polars-arrow/src/datatypes/schema.rs +++ b/crates/polars-arrow/src/datatypes/schema.rs @@ -1,93 +1,11 @@ use std::sync::Arc; -use polars_error::{polars_bail, PolarsResult}; -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; +use super::Field; -use super::{Field, Metadata}; - -/// An ordered sequence of [`Field`]s with associated [`Metadata`]. +/// An ordered sequence of [`Field`]s /// /// [`ArrowSchema`] is an abstraction used to read from, and write to, Arrow IPC format, /// Apache Parquet, and Apache Avro. All these formats have a concept of a schema /// with fields and metadata. -#[derive(Debug, Clone, PartialEq, Eq, Default)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct ArrowSchema { - /// The fields composing this schema. - pub fields: Vec, - /// Optional metadata. - pub metadata: Metadata, -} - +pub type ArrowSchema = polars_schema::Schema; pub type ArrowSchemaRef = Arc; - -impl ArrowSchema { - /// Attaches a [`Metadata`] to [`ArrowSchema`] - #[inline] - pub fn with_metadata(self, metadata: Metadata) -> Self { - Self { - fields: self.fields, - metadata, - } - } - - #[inline] - pub fn len(&self) -> usize { - self.fields.len() - } - - #[inline] - pub fn is_empty(&self) -> bool { - self.fields.is_empty() - } - - /// Returns a new [`ArrowSchema`] with a subset of all fields whose `predicate` - /// evaluates to true. - pub fn filter bool>(self, predicate: F) -> Self { - let fields = self - .fields - .into_iter() - .enumerate() - .filter_map(|(index, f)| { - if (predicate)(index, &f) { - Some(f) - } else { - None - } - }) - .collect(); - - ArrowSchema { - fields, - metadata: self.metadata, - } - } - - pub fn try_project(&self, indices: &[usize]) -> PolarsResult { - let fields = indices.iter().map(|&i| { - let Some(out) = self.fields.get(i) else { - polars_bail!( - SchemaFieldNotFound: "projection index {} is out of bounds for schema of length {}", - i, self.fields.len() - ); - }; - - Ok(out.clone()) - }).collect::>>()?; - - Ok(ArrowSchema { - fields, - metadata: self.metadata.clone(), - }) - } -} - -impl From> for ArrowSchema { - fn from(fields: Vec) -> Self { - Self { - fields, - ..Default::default() - } - } -} diff --git a/crates/polars-arrow/src/ffi/array.rs b/crates/polars-arrow/src/ffi/array.rs index 34abe43704ed..55090f1c760a 100644 --- a/crates/polars-arrow/src/ffi/array.rs +++ b/crates/polars-arrow/src/ffi/array.rs @@ -19,7 +19,7 @@ use crate::{match_integer_type, with_match_primitive_type_full}; /// * the interface is not valid (e.g. a null pointer) pub unsafe fn try_from(array: A) -> PolarsResult> { use PhysicalType::*; - Ok(match array.data_type().to_physical_type() { + Ok(match array.dtype().to_physical_type() { Null => Box::new(NullArray::try_from_ffi(array)?), Boolean => Box::new(BooleanArray::try_from_ffi(array)?), Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { @@ -99,7 +99,7 @@ impl ArrowArray { /// releasing this struct, or contents in `buffers` leak. 
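// The new `is_nested` helper above complements `inner_dtype`: list-like dtypes expose their
// child dtype, while structs, unions, maps, dictionaries and extensions report nested but
// have no single inner dtype. A small sketch (the dtype value is illustrative):
use polars_arrow::datatypes::{ArrowDataType, Field};
use polars_utils::pl_str::PlSmallStr;

fn nested_example() {
    let list = ArrowDataType::LargeList(Box::new(Field::new(
        PlSmallStr::from_static("item"),
        ArrowDataType::Int32,
        true,
    )));
    assert!(list.is_nested());
    assert_eq!(list.inner_dtype(), Some(&ArrowDataType::Int32));
}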
pub(crate) fn new(array: Box) -> Self { let needs_variadic_buffer_sizes = matches!( - array.data_type(), + array.dtype(), ArrowDataType::BinaryView | ArrowDataType::Utf8View ); @@ -207,12 +207,12 @@ impl ArrowArray { /// The caller must ensure that the buffer at index `i` is not mutably shared. unsafe fn get_buffer_ptr( array: &ArrowArray, - data_type: &ArrowDataType, + dtype: &ArrowDataType, index: usize, ) -> PolarsResult<*mut T> { if array.buffers.is_null() { polars_bail!( ComputeError: - "an ArrowArray of type {data_type:?} must have non-null buffers" + "an ArrowArray of type {dtype:?} must have non-null buffers" ); } @@ -222,7 +222,7 @@ unsafe fn get_buffer_ptr( != 0 { polars_bail!( ComputeError: - "an ArrowArray of type {data_type:?} + "an ArrowArray of type {dtype:?} must have buffer {index} aligned to type {}", std::any::type_name::<*mut *const u8>() ); @@ -231,7 +231,7 @@ unsafe fn get_buffer_ptr( if index >= array.n_buffers as usize { polars_bail!(ComputeError: - "An ArrowArray of type {data_type:?} + "An ArrowArray of type {dtype:?} must have buffer {index}." ) } @@ -239,7 +239,7 @@ unsafe fn get_buffer_ptr( let ptr = *buffers.add(index); if ptr.is_null() { polars_bail!(ComputeError: - "An array of type {data_type:?} + "An array of type {dtype:?} must have a non-null buffer {index}" ) } @@ -250,7 +250,7 @@ unsafe fn get_buffer_ptr( unsafe fn create_buffer_known_len( array: &ArrowArray, - data_type: &ArrowDataType, + dtype: &ArrowDataType, owner: InternalArrowArray, len: usize, index: usize, @@ -258,7 +258,7 @@ unsafe fn create_buffer_known_len( if len == 0 { return Ok(Buffer::new()); } - let ptr: *mut T = get_buffer_ptr(array, data_type, index)?; + let ptr: *mut T = get_buffer_ptr(array, dtype, index)?; let bytes = Bytes::from_foreign(ptr, len, BytesAllocator::InternalArrowArray(owner)); Ok(Buffer::from_bytes(bytes)) } @@ -270,18 +270,18 @@ unsafe fn create_buffer_known_len( /// * the buffers' pointers are not mutably shared for the lifetime of `owner` unsafe fn create_buffer( array: &ArrowArray, - data_type: &ArrowDataType, + dtype: &ArrowDataType, owner: InternalArrowArray, index: usize, ) -> PolarsResult> { - let len = buffer_len(array, data_type, index)?; + let len = buffer_len(array, dtype, index)?; if len == 0 { return Ok(Buffer::new()); } - let offset = buffer_offset(array, data_type, index); - let ptr: *mut T = get_buffer_ptr(array, data_type, index)?; + let offset = buffer_offset(array, dtype, index); + let ptr: *mut T = get_buffer_ptr(array, dtype, index)?; // We have to check alignment. // This is the zero-copy path. @@ -304,7 +304,7 @@ unsafe fn create_buffer( /// * the buffers' pointer is not mutable for the lifetime of `owner` unsafe fn create_bitmap( array: &ArrowArray, - data_type: &ArrowDataType, + dtype: &ArrowDataType, owner: InternalArrowArray, index: usize, // if this is the validity bitmap @@ -315,7 +315,7 @@ unsafe fn create_bitmap( if len == 0 { return Ok(Bitmap::new()); } - let ptr = get_buffer_ptr(array, data_type, index)?; + let ptr = get_buffer_ptr(array, dtype, index)?; // Pointer of u8 has alignment 1, so we don't have to check alignment. 
@@ -336,12 +336,12 @@ unsafe fn create_bitmap( )) } -fn buffer_offset(array: &ArrowArray, data_type: &ArrowDataType, i: usize) -> usize { +fn buffer_offset(array: &ArrowArray, dtype: &ArrowDataType, i: usize) -> usize { use PhysicalType::*; - match (data_type.to_physical_type(), i) { + match (dtype.to_physical_type(), i) { (LargeUtf8, 2) | (LargeBinary, 2) | (Utf8, 2) | (Binary, 2) => 0, (FixedSizeBinary, 1) => { - if let ArrowDataType::FixedSizeBinary(size) = data_type.to_logical_type() { + if let ArrowDataType::FixedSizeBinary(size) = dtype.to_logical_type() { let offset: usize = array.offset.try_into().expect("Offset to fit in `usize`"); offset * *size } else { @@ -353,21 +353,17 @@ fn buffer_offset(array: &ArrowArray, data_type: &ArrowDataType, i: usize) -> usi } /// Returns the length, in slots, of the buffer `i` (indexed according to the C data interface) -unsafe fn buffer_len( - array: &ArrowArray, - data_type: &ArrowDataType, - i: usize, -) -> PolarsResult { - Ok(match (data_type.to_physical_type(), i) { +unsafe fn buffer_len(array: &ArrowArray, dtype: &ArrowDataType, i: usize) -> PolarsResult { + Ok(match (dtype.to_physical_type(), i) { (PhysicalType::FixedSizeBinary, 1) => { - if let ArrowDataType::FixedSizeBinary(size) = data_type.to_logical_type() { + if let ArrowDataType::FixedSizeBinary(size) = dtype.to_logical_type() { *size * (array.offset as usize + array.length as usize) } else { unreachable!() } }, (PhysicalType::FixedSizeList, 1) => { - if let ArrowDataType::FixedSizeList(_, size) = data_type.to_logical_type() { + if let ArrowDataType::FixedSizeList(_, size) = dtype.to_logical_type() { *size * (array.offset as usize + array.length as usize) } else { unreachable!() @@ -388,7 +384,7 @@ unsafe fn buffer_len( }, (PhysicalType::Utf8, 2) | (PhysicalType::Binary, 2) => { // the len of the data buffer (buffer 2) equals the last value of the offset buffer (buffer 1) - let len = buffer_len(array, data_type, 1)?; + let len = buffer_len(array, dtype, 1)?; // first buffer is the null buffer => add(1) let offset_buffer = unsafe { *(array.buffers as *mut *const u8).add(1) }; // interpret as i32 @@ -399,7 +395,7 @@ unsafe fn buffer_len( }, (PhysicalType::LargeUtf8, 2) | (PhysicalType::LargeBinary, 2) => { // the len of the data buffer (buffer 2) equals the last value of the offset buffer (buffer 1) - let len = buffer_len(array, data_type, 1)?; + let len = buffer_len(array, dtype, 1)?; // first buffer is the null buffer => add(1) let offset_buffer = unsafe { *(array.buffers as *mut *const u8).add(1) }; // interpret as i64 @@ -421,20 +417,20 @@ unsafe fn buffer_len( /// * the pointer of `array.children` at `index` is not mutably shared for the lifetime of `parent` unsafe fn create_child( array: &ArrowArray, - data_type: &ArrowDataType, + dtype: &ArrowDataType, parent: InternalArrowArray, index: usize, ) -> PolarsResult> { - let data_type = get_child(data_type, index)?; + let dtype = get_child(dtype, index)?; // catch what we can if array.children.is_null() { - polars_bail!(ComputeError: "an ArrowArray of type {data_type:?} must have non-null children"); + polars_bail!(ComputeError: "an ArrowArray of type {dtype:?} must have non-null children"); } if index >= array.n_children as usize { polars_bail!(ComputeError: - "an ArrowArray of type {data_type:?} + "an ArrowArray of type {dtype:?} must have child {index}." 
); } @@ -445,14 +441,14 @@ unsafe fn create_child( // catch what we can if arr_ptr.is_null() { polars_bail!(ComputeError: - "an array of type {data_type:?} + "an array of type {dtype:?} must have a non-null child {index}" ) } // SAFETY: invariant of this function let arr_ptr = unsafe { &*arr_ptr }; - Ok(ArrowArrayChild::new(arr_ptr, data_type, parent)) + Ok(ArrowArrayChild::new(arr_ptr, dtype, parent)) } /// # Safety @@ -462,22 +458,22 @@ unsafe fn create_child( /// * `array.dictionary` is not mutably shared for the lifetime of `parent` unsafe fn create_dictionary( array: &ArrowArray, - data_type: &ArrowDataType, + dtype: &ArrowDataType, parent: InternalArrowArray, ) -> PolarsResult>> { - if let ArrowDataType::Dictionary(_, values, _) = data_type { - let data_type = values.as_ref().clone(); + if let ArrowDataType::Dictionary(_, values, _) = dtype { + let dtype = values.as_ref().clone(); // catch what we can if array.dictionary.is_null() { polars_bail!(ComputeError: - "an array of type {data_type:?} + "an array of type {dtype:?} must have a non-null dictionary" ) } // SAFETY: part of the invariant let array = unsafe { &*array.dictionary }; - Ok(Some(ArrowArrayChild::new(array, data_type, parent))) + Ok(Some(ArrowArrayChild::new(array, dtype, parent))) } else { Ok(None) } @@ -499,7 +495,7 @@ pub trait ArrowArrayRef: std::fmt::Debug { if self.array().null_count() == 0 { Ok(None) } else { - create_bitmap(self.array(), self.data_type(), self.owner(), 0, true).map(Some) + create_bitmap(self.array(), self.dtype(), self.owner(), 0, true).map(Some) } } @@ -507,7 +503,7 @@ pub trait ArrowArrayRef: std::fmt::Debug { /// The caller must guarantee that the buffer `index` corresponds to a buffer. /// This function assumes that the buffer created from FFI is valid; this is impossible to prove. 
unsafe fn buffer(&self, index: usize) -> PolarsResult> { - create_buffer::(self.array(), self.data_type(), self.owner(), index) + create_buffer::(self.array(), self.dtype(), self.owner(), index) } /// # Safety @@ -518,7 +514,7 @@ pub trait ArrowArrayRef: std::fmt::Debug { index: usize, len: usize, ) -> PolarsResult> { - create_buffer_known_len::(self.array(), self.data_type(), self.owner(), len, index) + create_buffer_known_len::(self.array(), self.dtype(), self.owner(), len, index) } /// # Safety @@ -526,7 +522,7 @@ pub trait ArrowArrayRef: std::fmt::Debug { /// * the buffer at position `index` is valid for the declared length /// * the buffers' pointer is not mutable for the lifetime of `owner` unsafe fn bitmap(&self, index: usize) -> PolarsResult { - create_bitmap(self.array(), self.data_type(), self.owner(), index, false) + create_bitmap(self.array(), self.dtype(), self.owner(), index, false) } /// # Safety @@ -535,11 +531,11 @@ pub trait ArrowArrayRef: std::fmt::Debug { /// * the pointer of `array.children` at `index` is valid /// * the pointer of `array.children` at `index` is not mutably shared for the lifetime of `parent` unsafe fn child(&self, index: usize) -> PolarsResult { - create_child(self.array(), self.data_type(), self.parent().clone(), index) + create_child(self.array(), self.dtype(), self.parent().clone(), index) } unsafe fn dictionary(&self) -> PolarsResult> { - create_dictionary(self.array(), self.data_type(), self.parent().clone()) + create_dictionary(self.array(), self.dtype(), self.parent().clone()) } fn n_buffers(&self) -> usize; @@ -549,7 +545,7 @@ pub trait ArrowArrayRef: std::fmt::Debug { fn parent(&self) -> &InternalArrowArray; fn array(&self) -> &ArrowArray; - fn data_type(&self) -> &ArrowDataType; + fn dtype(&self) -> &ArrowDataType; } /// Struct used to move an Array from and to the C Data Interface. 
@@ -576,22 +572,22 @@ pub struct InternalArrowArray { // Arc is used for sharability since this is immutable array: Arc, // Arced to reduce cost of cloning - data_type: Arc, + dtype: Arc, } impl InternalArrowArray { - pub fn new(array: ArrowArray, data_type: ArrowDataType) -> Self { + pub fn new(array: ArrowArray, dtype: ArrowDataType) -> Self { Self { array: Arc::new(array), - data_type: Arc::new(data_type), + dtype: Arc::new(dtype), } } } impl ArrowArrayRef for InternalArrowArray { - /// the data_type as declared in the schema - fn data_type(&self) -> &ArrowDataType { - &self.data_type + /// the dtype as declared in the schema + fn dtype(&self) -> &ArrowDataType { + &self.dtype } fn parent(&self) -> &InternalArrowArray { @@ -618,14 +614,14 @@ impl ArrowArrayRef for InternalArrowArray { #[derive(Debug)] pub struct ArrowArrayChild<'a> { array: &'a ArrowArray, - data_type: ArrowDataType, + dtype: ArrowDataType, parent: InternalArrowArray, } impl<'a> ArrowArrayRef for ArrowArrayChild<'a> { - /// the data_type as declared in the schema - fn data_type(&self) -> &ArrowDataType { - &self.data_type + /// the dtype as declared in the schema + fn dtype(&self) -> &ArrowDataType { + &self.dtype } fn parent(&self) -> &InternalArrowArray { @@ -650,10 +646,10 @@ impl<'a> ArrowArrayRef for ArrowArrayChild<'a> { } impl<'a> ArrowArrayChild<'a> { - fn new(array: &'a ArrowArray, data_type: ArrowDataType, parent: InternalArrowArray) -> Self { + fn new(array: &'a ArrowArray, dtype: ArrowDataType, parent: InternalArrowArray) -> Self { Self { array, - data_type, + dtype, parent, } } diff --git a/crates/polars-arrow/src/ffi/bridge.rs b/crates/polars-arrow/src/ffi/bridge.rs index 7c45ad2faa12..c23c21643214 100644 --- a/crates/polars-arrow/src/ffi/bridge.rs +++ b/crates/polars-arrow/src/ffi/bridge.rs @@ -14,7 +14,7 @@ macro_rules! ffi_dyn { pub fn align_to_c_data_interface(array: Box) -> Box { use crate::datatypes::PhysicalType::*; - match array.data_type().to_physical_type() { + match array.dtype().to_physical_type() { Null => ffi_dyn!(array, NullArray), Boolean => ffi_dyn!(array, BooleanArray), Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { diff --git a/crates/polars-arrow/src/ffi/mod.rs b/crates/polars-arrow/src/ffi/mod.rs index 7308a3b8a59e..b7cf2b957b0a 100644 --- a/crates/polars-arrow/src/ffi/mod.rs +++ b/crates/polars-arrow/src/ffi/mod.rs @@ -40,7 +40,7 @@ pub unsafe fn import_field_from_c(field: &ArrowSchema) -> PolarsResult { /// being valid according to the [C data interface](https://arrow.apache.org/docs/format/CDataInterface.html) (FFI). 
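// End-to-end sketch of the C data interface path touched here: export an array and its
// field, then import them back. The `export_*` helpers are assumed to live in this crate's
// `ffi` module alongside the import functions shown in this diff.
use polars_arrow::array::{Array, Int32Array};
use polars_arrow::datatypes::{ArrowDataType, Field};
use polars_arrow::ffi;
use polars_error::PolarsResult;
use polars_utils::pl_str::PlSmallStr;

fn ffi_roundtrip() -> PolarsResult<Box<dyn Array>> {
    let array = Int32Array::from_slice([1, 2, 3]);
    let field = Field::new(PlSmallStr::from_static("a"), ArrowDataType::Int32, true);

    let c_array = ffi::export_array_to_c(array.boxed());
    let c_schema = ffi::export_field_to_c(&field);

    // SAFETY: both structs were produced by this process and are valid per the C data interface.
    unsafe {
        let field = ffi::import_field_from_c(&c_schema)?;
        ffi::import_array_from_c(c_array, field.dtype.clone())
    }
}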
pub unsafe fn import_array_from_c( array: ArrowArray, - data_type: ArrowDataType, + dtype: ArrowDataType, ) -> PolarsResult> { - try_from(InternalArrowArray::new(array, data_type)) + try_from(InternalArrowArray::new(array, dtype)) } diff --git a/crates/polars-arrow/src/ffi/schema.rs b/crates/polars-arrow/src/ffi/schema.rs index f958311d7988..d6d89aaaa438 100644 --- a/crates/polars-arrow/src/ffi/schema.rs +++ b/crates/polars-arrow/src/ffi/schema.rs @@ -3,6 +3,7 @@ use std::ffi::{CStr, CString}; use std::ptr; use polars_error::{polars_bail, polars_err, PolarsResult}; +use polars_utils::pl_str::PlSmallStr; use super::ArrowSchema; use crate::datatypes::{ @@ -38,8 +39,8 @@ unsafe extern "C" fn c_release_schema(schema: *mut ArrowSchema) { } /// allocate (and hold) the children -fn schema_children(data_type: &ArrowDataType, flags: &mut i64) -> Box<[*mut ArrowSchema]> { - match data_type { +fn schema_children(dtype: &ArrowDataType, flags: &mut i64) -> Box<[*mut ArrowSchema]> { + match dtype { ArrowDataType::List(field) | ArrowDataType::FixedSizeList(field, _) | ArrowDataType::LargeList(field) => { @@ -61,20 +62,19 @@ fn schema_children(data_type: &ArrowDataType, flags: &mut i64) -> Box<[*mut Arro impl ArrowSchema { /// creates a new [ArrowSchema] pub(crate) fn new(field: &Field) -> Self { - let format = to_format(field.data_type()); + let format = to_format(field.dtype()); let name = field.name.clone(); let mut flags = field.is_nullable as i64 * 2; // note: this cannot be done along with the above because the above is fallible and this op leaks. - let children_ptr = schema_children(field.data_type(), &mut flags); + let children_ptr = schema_children(field.dtype(), &mut flags); let n_children = children_ptr.len() as i64; - let dictionary = if let ArrowDataType::Dictionary(_, values, is_ordered) = field.data_type() - { + let dictionary = if let ArrowDataType::Dictionary(_, values, is_ordered) = field.dtype() { flags += *is_ordered as i64; // we do not store field info in the dict values, so can't recover it all :( - let field = Field::new("", values.as_ref().clone(), true); + let field = Field::new(PlSmallStr::EMPTY, values.as_ref().clone(), true); Some(Box::new(ArrowSchema::new(&field))) } else { None @@ -82,29 +82,32 @@ impl ArrowSchema { let metadata = &field.metadata; - let metadata = - if let ArrowDataType::Extension(name, _, extension_metadata) = field.data_type() { - // append extension information. - let mut metadata = metadata.clone(); - - // metadata - if let Some(extension_metadata) = extension_metadata { - metadata.insert( - "ARROW:extension:metadata".to_string(), - extension_metadata.clone(), - ); - } + let metadata = if let ArrowDataType::Extension(name, _, extension_metadata) = field.dtype() + { + // append extension information. 
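// The two "ARROW:extension:*" metadata keys handled in this hunk are what `get_extension`
// (in datatypes/mod.rs, earlier in this diff) reads back; a tiny sketch with a hypothetical
// extension name:
use polars_arrow::datatypes::{get_extension, Metadata};
use polars_utils::pl_str::PlSmallStr;

fn extension_roundtrip() {
    let mut md = Metadata::new();
    md.insert(
        PlSmallStr::from_static("ARROW:extension:name"),
        PlSmallStr::from_static("my.ext"),
    );
    // No "ARROW:extension:metadata" entry was inserted, so the second tuple slot is None.
    assert_eq!(
        get_extension(&md),
        Some((PlSmallStr::from_static("my.ext"), None))
    );
}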
+ let mut metadata = metadata.clone(); + + // metadata + if let Some(extension_metadata) = extension_metadata { + metadata.insert( + PlSmallStr::from_static("ARROW:extension:metadata"), + extension_metadata.clone(), + ); + } - metadata.insert("ARROW:extension:name".to_string(), name.clone()); + metadata.insert( + PlSmallStr::from_static("ARROW:extension:name"), + name.clone(), + ); - Some(metadata_to_bytes(&metadata)) - } else if !metadata.is_empty() { - Some(metadata_to_bytes(metadata)) - } else { - None - }; + Some(metadata_to_bytes(&metadata)) + } else if !metadata.is_empty() { + Some(metadata_to_bytes(metadata)) + } else { + None + }; - let name = CString::new(name).unwrap(); + let name = CString::new(name.as_bytes()).unwrap(); let format = CString::new(format).unwrap(); let mut private = Box::new(SchemaPrivateData { @@ -200,23 +203,28 @@ impl Drop for ArrowSchema { pub(crate) unsafe fn to_field(schema: &ArrowSchema) -> PolarsResult { let dictionary = schema.dictionary(); - let data_type = if let Some(dictionary) = dictionary { + let dtype = if let Some(dictionary) = dictionary { let indices = to_integer_type(schema.format())?; let values = to_field(dictionary)?; let is_ordered = schema.flags & 1 == 1; - ArrowDataType::Dictionary(indices, Box::new(values.data_type().clone()), is_ordered) + ArrowDataType::Dictionary(indices, Box::new(values.dtype().clone()), is_ordered) } else { - to_data_type(schema)? + to_dtype(schema)? }; let (metadata, extension) = unsafe { metadata_from_bytes(schema.metadata) }; - let data_type = if let Some((name, extension_metadata)) = extension { - ArrowDataType::Extension(name, Box::new(data_type), extension_metadata) + let dtype = if let Some((name, extension_metadata)) = extension { + ArrowDataType::Extension(name, Box::new(dtype), extension_metadata) } else { - data_type + dtype }; - Ok(Field::new(schema.name(), data_type, schema.nullable()).with_metadata(metadata)) + Ok(Field::new( + PlSmallStr::from_str(schema.name()), + dtype, + schema.nullable(), + ) + .with_metadata(metadata)) } fn to_integer_type(format: &str) -> PolarsResult { @@ -239,7 +247,7 @@ fn to_integer_type(format: &str) -> PolarsResult { }) } -unsafe fn to_data_type(schema: &ArrowSchema) -> PolarsResult { +unsafe fn to_dtype(schema: &ArrowSchema) -> PolarsResult { Ok(match schema.format() { "n" => ArrowDataType::Null, "b" => ArrowDataType::Boolean, @@ -301,14 +309,18 @@ unsafe fn to_data_type(schema: &ArrowSchema) -> PolarsResult { ["tsn", ""] => ArrowDataType::Timestamp(TimeUnit::Nanosecond, None), // Timestamps with timezone - ["tss", tz] => ArrowDataType::Timestamp(TimeUnit::Second, Some(tz.to_string())), + ["tss", tz] => { + ArrowDataType::Timestamp(TimeUnit::Second, Some(PlSmallStr::from_str(tz))) + }, ["tsm", tz] => { - ArrowDataType::Timestamp(TimeUnit::Millisecond, Some(tz.to_string())) + ArrowDataType::Timestamp(TimeUnit::Millisecond, Some(PlSmallStr::from_str(tz))) }, ["tsu", tz] => { - ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(tz.to_string())) + ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(PlSmallStr::from_str(tz))) + }, + ["tsn", tz] => { + ArrowDataType::Timestamp(TimeUnit::Nanosecond, Some(PlSmallStr::from_str(tz))) }, - ["tsn", tz] => ArrowDataType::Timestamp(TimeUnit::Nanosecond, Some(tz.to_string())), ["w", size_raw] => { // Example: "w:42" fixed-width binary [42 bytes] @@ -401,8 +413,8 @@ unsafe fn to_data_type(schema: &ArrowSchema) -> PolarsResult { } /// the inverse of [to_field] -fn to_format(data_type: &ArrowDataType) -> String { - match data_type { 
+fn to_format(dtype: &ArrowDataType) -> String { + match dtype { ArrowDataType::Null => "n".to_string(), ArrowDataType::Boolean => "b".to_string(), ArrowDataType::Int8 => "c".to_string(), @@ -451,7 +463,7 @@ fn to_format(data_type: &ArrowDataType) -> String { format!( "ts{}:{}", unit, - tz.as_ref().map(|x| x.as_ref()).unwrap_or("") + tz.as_ref().map(|x| x.as_str()).unwrap_or("") ) }, ArrowDataType::Utf8View => "vu".to_string(), @@ -468,9 +480,9 @@ fn to_format(data_type: &ArrowDataType) -> String { let mut r = format!("+u{sparsness}:"); let ids = if let Some(ids) = ids { ids.iter() - .fold(String::new(), |a, b| a + &b.to_string() + ",") + .fold(String::new(), |a, b| a + b.to_string().as_str() + ",") } else { - (0..f.len()).fold(String::new(), |a, b| a + &b.to_string() + ",") + (0..f.len()).fold(String::new(), |a, b| a + b.to_string().as_str() + ",") }; let ids = &ids[..ids.len() - 1]; // take away last "," r.push_str(ids); @@ -483,22 +495,22 @@ fn to_format(data_type: &ArrowDataType) -> String { } } -pub(super) fn get_child(data_type: &ArrowDataType, index: usize) -> PolarsResult { - match (index, data_type) { - (0, ArrowDataType::List(field)) => Ok(field.data_type().clone()), - (0, ArrowDataType::FixedSizeList(field, _)) => Ok(field.data_type().clone()), - (0, ArrowDataType::LargeList(field)) => Ok(field.data_type().clone()), - (0, ArrowDataType::Map(field, _)) => Ok(field.data_type().clone()), - (index, ArrowDataType::Struct(fields)) => Ok(fields[index].data_type().clone()), - (index, ArrowDataType::Union(fields, _, _)) => Ok(fields[index].data_type().clone()), +pub(super) fn get_child(dtype: &ArrowDataType, index: usize) -> PolarsResult { + match (index, dtype) { + (0, ArrowDataType::List(field)) => Ok(field.dtype().clone()), + (0, ArrowDataType::FixedSizeList(field, _)) => Ok(field.dtype().clone()), + (0, ArrowDataType::LargeList(field)) => Ok(field.dtype().clone()), + (0, ArrowDataType::Map(field, _)) => Ok(field.dtype().clone()), + (index, ArrowDataType::Struct(fields)) => Ok(fields[index].dtype().clone()), + (index, ArrowDataType::Union(fields, _, _)) => Ok(fields[index].dtype().clone()), (index, ArrowDataType::Extension(_, subtype, _)) => get_child(subtype, index), - (child, data_type) => polars_bail!(ComputeError: - "Requested child {child} to type {data_type:?} that has no such child", + (child, dtype) => polars_bail!(ComputeError: + "Requested child {child} to type {dtype:?} that has no such child", ), } } -fn metadata_to_bytes(metadata: &BTreeMap) -> Vec { +fn metadata_to_bytes(metadata: &BTreeMap) -> Vec { let a = (metadata.len() as i32).to_ne_bytes().to_vec(); metadata.iter().fold(a, |mut acc, (key, value)| { acc.extend((key.len() as i32).to_ne_bytes()); @@ -541,13 +553,13 @@ unsafe fn metadata_from_bytes(data: *const ::std::os::raw::c_char) -> (Metadata, data = data.add(value_len); match key { "ARROW:extension:name" => { - extension_name = Some(value.to_string()); + extension_name = Some(PlSmallStr::from_str(value)); }, "ARROW:extension:metadata" => { - extension_metadata = Some(value.to_string()); + extension_metadata = Some(PlSmallStr::from_str(value)); }, _ => { - result.insert(key.to_string(), value.to_string()); + result.insert(PlSmallStr::from_str(key), PlSmallStr::from_str(value)); }, }; } @@ -587,35 +599,50 @@ mod tests { ArrowDataType::LargeBinary, ArrowDataType::FixedSizeBinary(2), ArrowDataType::List(Box::new(Field::new( - "example", + PlSmallStr::from_static("example"), ArrowDataType::Boolean, false, ))), ArrowDataType::FixedSizeList( - 
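// The timezone-aware timestamp format strings above follow the C data interface convention
// "ts{s,m,u,n}:{tz}". `to_format`/`to_dtype` stay private to this module, so this is only an
// illustrative helper mirroring that mapping, not part of the crate's API:
use polars_arrow::datatypes::TimeUnit;

fn timestamp_format(unit: TimeUnit, tz: Option<&str>) -> String {
    let unit = match unit {
        TimeUnit::Second => "s",
        TimeUnit::Millisecond => "m",
        TimeUnit::Microsecond => "u",
        TimeUnit::Nanosecond => "n",
    };
    format!("ts{unit}:{}", tz.unwrap_or(""))
}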
Box::new(Field::new("example", ArrowDataType::Boolean, false)), + Box::new(Field::new( + PlSmallStr::from_static("example"), + ArrowDataType::Boolean, + false, + )), 2, ), ArrowDataType::LargeList(Box::new(Field::new( - "example", + PlSmallStr::from_static("example"), ArrowDataType::Boolean, false, ))), ArrowDataType::Struct(vec![ - Field::new("a", ArrowDataType::Int64, true), + Field::new(PlSmallStr::from_static("a"), ArrowDataType::Int64, true), Field::new( - "b", - ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, true))), + PlSmallStr::from_static("b"), + ArrowDataType::List(Box::new(Field::new( + PlSmallStr::from_static("item"), + ArrowDataType::Int32, + true, + ))), true, ), ]), - ArrowDataType::Map(Box::new(Field::new("a", ArrowDataType::Int64, true)), true), + ArrowDataType::Map( + Box::new(Field::new( + PlSmallStr::from_static("a"), + ArrowDataType::Int64, + true, + )), + true, + ), ArrowDataType::Union( vec![ - Field::new("a", ArrowDataType::Int64, true), + Field::new(PlSmallStr::from_static("a"), ArrowDataType::Int64, true), Field::new( - "b", + PlSmallStr::from_static("b"), ArrowDataType::List(Box::new(Field::new( - "item", + PlSmallStr::from_static("item"), ArrowDataType::Int32, true, ))), @@ -627,11 +654,11 @@ mod tests { ), ArrowDataType::Union( vec![ - Field::new("a", ArrowDataType::Int64, true), + Field::new(PlSmallStr::from_static("a"), ArrowDataType::Int64, true), Field::new( - "b", + PlSmallStr::from_static("b"), ArrowDataType::List(Box::new(Field::new( - "item", + PlSmallStr::from_static("item"), ArrowDataType::Int32, true, ))), @@ -651,7 +678,7 @@ mod tests { dts.push(ArrowDataType::Timestamp(time_unit, None)); dts.push(ArrowDataType::Timestamp( time_unit, - Some("00:00".to_string()), + Some(PlSmallStr::from_static("00:00")), )); dts.push(ArrowDataType::Duration(time_unit)); } @@ -664,9 +691,9 @@ mod tests { } for expected in dts { - let field = Field::new("a", expected.clone(), true); + let field = Field::new(PlSmallStr::from_static("a"), expected.clone(), true); let schema = ArrowSchema::new(&field); - let result = unsafe { super::to_data_type(&schema).unwrap() }; + let result = unsafe { super::to_dtype(&schema).unwrap() }; assert_eq!(result, expected); } } diff --git a/crates/polars-arrow/src/ffi/stream.rs b/crates/polars-arrow/src/ffi/stream.rs index b894bc6748ab..2666d417ec48 100644 --- a/crates/polars-arrow/src/ffi/stream.rs +++ b/crates/polars-arrow/src/ffi/stream.rs @@ -120,7 +120,7 @@ impl> ArrowArrayStreamReader { array.release?; // SAFETY: assumed from the C stream interface - unsafe { import_array_from_c(array, self.field.data_type.clone()) } + unsafe { import_array_from_c(array, self.field.dtype.clone()) } .map(Some) .transpose() } @@ -140,9 +140,9 @@ unsafe extern "C" fn get_next(iter: *mut ArrowArrayStream, array: *mut ArrowArra match private.iter.next() { Some(Ok(item)) => { - // check that the array has the same data_type as field - let item_dt = item.data_type(); - let expected_dt = private.field.data_type(); + // check that the array has the same dtype as field + let item_dt = item.dtype(); + let expected_dt = private.field.dtype(); if item_dt != expected_dt { private.error = Some(CString::new(format!("The iterator produced an item of data type {item_dt:?} but the producer expects data type {expected_dt:?}").as_bytes().to_vec()).unwrap()); return 2001; // custom application specific error (since this is never a result of this interface) diff --git a/crates/polars-arrow/src/io/avro/read/deserialize.rs 
b/crates/polars-arrow/src/io/avro/read/deserialize.rs index 36297a37621d..f9423f83305a 100644 --- a/crates/polars-arrow/src/io/avro/read/deserialize.rs +++ b/crates/polars-arrow/src/io/avro/read/deserialize.rs @@ -11,16 +11,16 @@ use crate::types::months_days_ns; use crate::with_match_primitive_type_full; fn make_mutable( - data_type: &ArrowDataType, + dtype: &ArrowDataType, avro_field: Option<&AvroSchema>, capacity: usize, ) -> PolarsResult> { - Ok(match data_type.to_physical_type() { + Ok(match dtype.to_physical_type() { PhysicalType::Boolean => { Box::new(MutableBooleanArray::with_capacity(capacity)) as Box }, PhysicalType::Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { - Box::new(MutablePrimitiveArray::<$T>::with_capacity(capacity).to(data_type.clone())) + Box::new(MutablePrimitiveArray::<$T>::with_capacity(capacity).to(dtype.clone())) as Box }), PhysicalType::Binary => { @@ -38,12 +38,12 @@ fn make_mutable( unreachable!() } }, - _ => match data_type { + _ => match dtype { ArrowDataType::List(inner) => { - let values = make_mutable(inner.data_type(), None, 0)?; + let values = make_mutable(inner.dtype(), None, 0)?; Box::new(DynMutableListArray::::new_from( values, - data_type.clone(), + dtype.clone(), capacity, )) as Box }, @@ -54,10 +54,9 @@ fn make_mutable( ArrowDataType::Struct(fields) => { let values = fields .iter() - .map(|field| make_mutable(field.data_type(), None, capacity)) + .map(|field| make_mutable(field.dtype(), None, capacity)) .collect::>>()?; - Box::new(DynMutableStructArray::new(values, data_type.clone())) - as Box + Box::new(DynMutableStructArray::new(values, dtype.clone())) as Box }, other => { polars_bail!(nyi = "Deserializing type {other:#?} is still not implemented") @@ -96,8 +95,8 @@ fn deserialize_value<'a>( avro_field: &AvroSchema, mut block: &'a [u8], ) -> PolarsResult<&'a [u8]> { - let data_type = array.data_type(); - match data_type { + let dtype = array.dtype(); + match dtype { ArrowDataType::List(inner) => { let is_nullable = inner.is_nullable; let avro_inner = match avro_field { @@ -168,7 +167,7 @@ fn deserialize_value<'a>( } array.try_push_valid()?; }, - _ => match data_type.to_physical_type() { + _ => match dtype.to_physical_type() { PhysicalType::Boolean => { let is_valid = block[0] == 1; block = &block[1..]; @@ -331,7 +330,7 @@ fn skip_item<'a>( return Ok(block); } } - match &field.data_type { + match &field.dtype { ArrowDataType::List(inner) => { let avro_inner = match avro_field { AvroSchema::Array(inner) => inner.as_ref(), @@ -392,7 +391,7 @@ fn skip_item<'a>( block = skip_item(field, &avro_field.schema, block)?; } }, - _ => match field.data_type.to_physical_type() { + _ => match field.dtype.to_physical_type() { PhysicalType::Boolean => { let _ = block[0] == 1; block = &block[1..]; @@ -444,7 +443,7 @@ fn skip_item<'a>( block = &block[len..]; }, PhysicalType::FixedSizeBinary => { - let len = if let ArrowDataType::FixedSizeBinary(len) = &field.data_type { + let len = if let ArrowDataType::FixedSizeBinary(len) = &field.dtype { *len } else { unreachable!() @@ -467,7 +466,7 @@ fn skip_item<'a>( /// `fields`, `avro_fields` and `projection` must have the same length. 
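// The avro `Reader` below now carries an `ArrowSchema` instead of a `Vec<Field>`. A rough
// end-to-end sketch (feature-gated `io::avro` path assumed; reading the file header via the
// companion `avro_schema` crate is an assumption not shown in this diff):
use std::fs::File;
use std::io::BufReader;

use polars_arrow::io::avro::read::{infer_schema, Reader};

fn read_avro(path: &str) -> Result<(), Box<dyn std::error::Error>> {
    let mut file = BufReader::new(File::open(path)?);
    let metadata = avro_schema::read::read_metadata(&mut file)?;
    let schema = infer_schema(&metadata.record)?;
    let reader = Reader::new(file, metadata, schema, None);
    for batch in reader {
        let _batch = batch?; // one RecordBatchT per avro block
    }
    Ok(())
}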
pub fn deserialize( block: &Block, - fields: &[Field], + fields: &ArrowSchema, avro_fields: &[AvroField], projection: &[bool], ) -> PolarsResult>> { @@ -479,12 +478,12 @@ pub fn deserialize( // create mutables, one per field let mut arrays: Vec> = fields - .iter() + .iter_values() .zip(avro_fields.iter()) .zip(projection.iter()) .map(|((field, avro_field), projection)| { if *projection { - make_mutable(&field.data_type, Some(&avro_field.schema), rows) + make_mutable(&field.dtype, Some(&avro_field.schema), rows) } else { // just something; we are not going to use it make_mutable(&ArrowDataType::Int32, None, 0) @@ -496,7 +495,7 @@ pub fn deserialize( for _ in 0..rows { let iter = arrays .iter_mut() - .zip(fields.iter()) + .zip(fields.iter_values()) .zip(avro_fields.iter()) .zip(projection.iter()); diff --git a/crates/polars-arrow/src/io/avro/read/mod.rs b/crates/polars-arrow/src/io/avro/read/mod.rs index 701bf8a39579..fb321d5c8e6e 100644 --- a/crates/polars-arrow/src/io/avro/read/mod.rs +++ b/crates/polars-arrow/src/io/avro/read/mod.rs @@ -17,14 +17,14 @@ mod util; pub use schema::infer_schema; use crate::array::Array; -use crate::datatypes::Field; +use crate::datatypes::ArrowSchema; use crate::record_batch::RecordBatchT; /// Single threaded, blocking reader of Avro; [`Iterator`] of [`RecordBatchT`]. pub struct Reader { iter: BlockStreamingIterator, avro_fields: Vec, - fields: Vec, + fields: ArrowSchema, projection: Vec, } @@ -33,7 +33,7 @@ impl Reader { pub fn new( reader: R, metadata: FileMetadata, - fields: Vec, + fields: ArrowSchema, projection: Option>, ) -> Self { let projection = projection.unwrap_or_else(|| fields.iter().map(|_| true).collect()); @@ -56,7 +56,7 @@ impl Iterator for Reader { type Item = PolarsResult>>; fn next(&mut self) -> Option { - let fields = &self.fields[..]; + let fields = &self.fields; let avro_fields = &self.avro_fields; let projection = &self.projection; diff --git a/crates/polars-arrow/src/io/avro/read/nested.rs b/crates/polars-arrow/src/io/avro/read/nested.rs index 7188e06ae873..fc7e07487d83 100644 --- a/crates/polars-arrow/src/io/avro/read/nested.rs +++ b/crates/polars-arrow/src/io/avro/read/nested.rs @@ -8,22 +8,18 @@ use crate::offset::{Offset, Offsets}; /// Auxiliary struct #[derive(Debug)] pub struct DynMutableListArray { - data_type: ArrowDataType, + dtype: ArrowDataType, offsets: Offsets, values: Box, validity: Option, } impl DynMutableListArray { - pub fn new_from( - values: Box, - data_type: ArrowDataType, - capacity: usize, - ) -> Self { + pub fn new_from(values: Box, dtype: ArrowDataType, capacity: usize) -> Self { assert_eq!(values.len(), 0); - ListArray::::get_child_field(&data_type); + ListArray::::get_child_field(&dtype); Self { - data_type, + dtype, offsets: Offsets::::with_capacity(capacity), values, validity: None, @@ -80,7 +76,7 @@ impl MutableArray for DynMutableListArray { fn as_box(&mut self) -> Box { ListArray::new( - self.data_type.clone(), + self.dtype.clone(), std::mem::take(&mut self.offsets).into(), self.values.as_box(), std::mem::take(&mut self.validity).map(|x| x.into()), @@ -90,7 +86,7 @@ impl MutableArray for DynMutableListArray { fn as_arc(&mut self) -> std::sync::Arc { ListArray::new( - self.data_type.clone(), + self.dtype.clone(), std::mem::take(&mut self.offsets).into(), self.values.as_box(), std::mem::take(&mut self.validity).map(|x| x.into()), @@ -98,8 +94,8 @@ impl MutableArray for DynMutableListArray { .arced() } - fn data_type(&self) -> &ArrowDataType { - &self.data_type + fn dtype(&self) -> &ArrowDataType { + 
&self.dtype } fn as_any(&self) -> &dyn std::any::Any { @@ -126,7 +122,7 @@ impl MutableArray for DynMutableListArray { #[derive(Debug)] pub struct FixedItemsUtf8Dictionary { - data_type: ArrowDataType, + dtype: ArrowDataType, keys: MutablePrimitiveArray, values: Utf8Array, } @@ -134,9 +130,9 @@ pub struct FixedItemsUtf8Dictionary { impl FixedItemsUtf8Dictionary { pub fn with_capacity(values: Utf8Array, capacity: usize) -> Self { Self { - data_type: ArrowDataType::Dictionary( + dtype: ArrowDataType::Dictionary( IntegerType::Int32, - Box::new(values.data_type().clone()), + Box::new(values.dtype().clone()), false, ), keys: MutablePrimitiveArray::::with_capacity(capacity), @@ -166,7 +162,7 @@ impl MutableArray for FixedItemsUtf8Dictionary { fn as_box(&mut self) -> Box { Box::new( DictionaryArray::try_new( - self.data_type.clone(), + self.dtype.clone(), std::mem::take(&mut self.keys).into(), Box::new(self.values.clone()), ) @@ -177,7 +173,7 @@ impl MutableArray for FixedItemsUtf8Dictionary { fn as_arc(&mut self) -> std::sync::Arc { std::sync::Arc::new( DictionaryArray::try_new( - self.data_type.clone(), + self.dtype.clone(), std::mem::take(&mut self.keys).into(), Box::new(self.values.clone()), ) @@ -185,8 +181,8 @@ impl MutableArray for FixedItemsUtf8Dictionary { ) } - fn data_type(&self) -> &ArrowDataType { - &self.data_type + fn dtype(&self) -> &ArrowDataType { + &self.dtype } fn as_any(&self) -> &dyn std::any::Any { @@ -214,15 +210,15 @@ impl MutableArray for FixedItemsUtf8Dictionary { /// Auxiliary struct #[derive(Debug)] pub struct DynMutableStructArray { - data_type: ArrowDataType, + dtype: ArrowDataType, values: Vec>, validity: Option, } impl DynMutableStructArray { - pub fn new(values: Vec>, data_type: ArrowDataType) -> Self { + pub fn new(values: Vec>, dtype: ArrowDataType) -> Self { Self { - data_type, + dtype, values, validity: None, } @@ -273,7 +269,7 @@ impl MutableArray for DynMutableStructArray { let values = self.values.iter_mut().map(|x| x.as_box()).collect(); Box::new(StructArray::new( - self.data_type.clone(), + self.dtype.clone(), values, std::mem::take(&mut self.validity).map(|x| x.into()), )) @@ -283,14 +279,14 @@ impl MutableArray for DynMutableStructArray { let values = self.values.iter_mut().map(|x| x.as_box()).collect(); std::sync::Arc::new(StructArray::new( - self.data_type.clone(), + self.dtype.clone(), values, std::mem::take(&mut self.validity).map(|x| x.into()), )) } - fn data_type(&self) -> &ArrowDataType { - &self.data_type + fn dtype(&self) -> &ArrowDataType { + &self.dtype } fn as_any(&self) -> &dyn std::any::Any { diff --git a/crates/polars-arrow/src/io/avro/read/schema.rs b/crates/polars-arrow/src/io/avro/read/schema.rs index a29402ae600f..ae9660496c7f 100644 --- a/crates/polars-arrow/src/io/avro/read/schema.rs +++ b/crates/polars-arrow/src/io/avro/read/schema.rs @@ -1,18 +1,22 @@ use avro_schema::schema::{Enum, Fixed, Record, Schema as AvroSchema}; use polars_error::{polars_bail, PolarsResult}; +use polars_utils::pl_str::PlSmallStr; use crate::datatypes::*; fn external_props(schema: &AvroSchema) -> Metadata { let mut props = Metadata::new(); - match &schema { + match schema { AvroSchema::Record(Record { doc: Some(ref doc), .. }) | AvroSchema::Enum(Enum { doc: Some(ref doc), .. 
}) => { - props.insert("avro::doc".to_string(), doc.clone()); + props.insert( + PlSmallStr::from_static("avro::doc"), + PlSmallStr::from_str(doc.as_str()), + ); }, _ => {}, } @@ -22,18 +26,19 @@ fn external_props(schema: &AvroSchema) -> Metadata { /// Infers an [`ArrowSchema`] from the root [`Record`]. /// This pub fn infer_schema(record: &Record) -> PolarsResult { - Ok(record + record .fields .iter() .map(|field| { - schema_to_field( + let field = schema_to_field( &field.schema, Some(&field.name), external_props(&field.schema), - ) + )?; + + Ok((field.name.clone(), field)) }) - .collect::>>()? - .into()) + .collect::>() } fn schema_to_field( @@ -42,7 +47,7 @@ fn schema_to_field( props: Metadata, ) -> PolarsResult { let mut nullable = false; - let data_type = match schema { + let dtype = match schema { AvroSchema::Null => ArrowDataType::Null, AvroSchema::Boolean => ArrowDataType::Boolean, AvroSchema::Int(logical) => match logical { @@ -59,12 +64,14 @@ fn schema_to_field( avro_schema::schema::LongLogical::Time => { ArrowDataType::Time64(TimeUnit::Microsecond) }, - avro_schema::schema::LongLogical::TimestampMillis => { - ArrowDataType::Timestamp(TimeUnit::Millisecond, Some("00:00".to_string())) - }, - avro_schema::schema::LongLogical::TimestampMicros => { - ArrowDataType::Timestamp(TimeUnit::Microsecond, Some("00:00".to_string())) - }, + avro_schema::schema::LongLogical::TimestampMillis => ArrowDataType::Timestamp( + TimeUnit::Millisecond, + Some(PlSmallStr::from_static("00:00")), + ), + avro_schema::schema::LongLogical::TimestampMicros => ArrowDataType::Timestamp( + TimeUnit::Microsecond, + Some(PlSmallStr::from_static("00:00")), + ), avro_schema::schema::LongLogical::LocalTimestampMillis => { ArrowDataType::Timestamp(TimeUnit::Millisecond, None) }, @@ -100,7 +107,7 @@ fn schema_to_field( .iter() .find(|&schema| !matches!(schema, AvroSchema::Null)) { - schema_to_field(schema, None, Metadata::default())?.data_type + schema_to_field(schema, None, Metadata::default())?.dtype } else { polars_bail!(nyi = "Can't read avro union {schema:?}"); } @@ -118,7 +125,10 @@ fn schema_to_field( .map(|field| { let mut props = Metadata::new(); if let Some(doc) = &field.doc { - props.insert("avro::doc".to_string(), doc.clone()); + props.insert( + PlSmallStr::from_static("avro::doc"), + PlSmallStr::from_str(doc), + ); } schema_to_field(&field.schema, Some(&field.name), props) }) @@ -127,7 +137,7 @@ fn schema_to_field( }, AvroSchema::Enum { .. 
} => { return Ok(Field::new( - name.unwrap_or_default(), + PlSmallStr::from_str(name.unwrap_or_default()), ArrowDataType::Dictionary(IntegerType::Int32, Box::new(ArrowDataType::Utf8), false), false, )) @@ -147,5 +157,5 @@ fn schema_to_field( let name = name.unwrap_or_default(); - Ok(Field::new(name, data_type, nullable).with_metadata(props)) + Ok(Field::new(PlSmallStr::from_str(name), dtype, nullable).with_metadata(props)) } diff --git a/crates/polars-arrow/src/io/avro/write/schema.rs b/crates/polars-arrow/src/io/avro/write/schema.rs index 8171798a692c..e0d71c5611c3 100644 --- a/crates/polars-arrow/src/io/avro/write/schema.rs +++ b/crates/polars-arrow/src/io/avro/write/schema.rs @@ -10,8 +10,7 @@ use crate::datatypes::*; pub fn to_record(schema: &ArrowSchema, name: String) -> PolarsResult { let mut name_counter: i32 = 0; let fields = schema - .fields - .iter() + .iter_values() .map(|f| field_to_field(f, &mut name_counter)) .collect::>()?; Ok(Record { @@ -24,22 +23,22 @@ pub fn to_record(schema: &ArrowSchema, name: String) -> PolarsResult { } fn field_to_field(field: &Field, name_counter: &mut i32) -> PolarsResult { - let schema = type_to_schema(field.data_type(), field.is_nullable, name_counter)?; - Ok(AvroField::new(&field.name, schema)) + let schema = type_to_schema(field.dtype(), field.is_nullable, name_counter)?; + Ok(AvroField::new(field.name.to_string(), schema)) } fn type_to_schema( - data_type: &ArrowDataType, + dtype: &ArrowDataType, is_nullable: bool, name_counter: &mut i32, ) -> PolarsResult { Ok(if is_nullable { AvroSchema::Union(vec![ AvroSchema::Null, - _type_to_schema(data_type, name_counter)?, + _type_to_schema(dtype, name_counter)?, ]) } else { - _type_to_schema(data_type, name_counter)? + _type_to_schema(dtype, name_counter)? }) } @@ -48,8 +47,8 @@ fn _get_field_name(name_counter: &mut i32) -> String { format!("r{name_counter}") } -fn _type_to_schema(data_type: &ArrowDataType, name_counter: &mut i32) -> PolarsResult { - Ok(match data_type.to_logical_type() { +fn _type_to_schema(dtype: &ArrowDataType, name_counter: &mut i32) -> PolarsResult { + Ok(match dtype.to_logical_type() { ArrowDataType::Null => AvroSchema::Null, ArrowDataType::Boolean => AvroSchema::Boolean, ArrowDataType::Int32 => AvroSchema::Int(None), @@ -62,7 +61,7 @@ fn _type_to_schema(data_type: &ArrowDataType, name_counter: &mut i32) -> PolarsR ArrowDataType::LargeUtf8 => AvroSchema::String(None), ArrowDataType::LargeList(inner) | ArrowDataType::List(inner) => { AvroSchema::Array(Box::new(type_to_schema( - &inner.data_type, + &inner.dtype, inner.is_nullable, name_counter, )?)) diff --git a/crates/polars-arrow/src/io/avro/write/serialize.rs b/crates/polars-arrow/src/io/avro/write/serialize.rs index 36519acbf493..ba287521b677 100644 --- a/crates/polars-arrow/src/io/avro/write/serialize.rs +++ b/crates/polars-arrow/src/io/avro/write/serialize.rs @@ -207,14 +207,14 @@ fn struct_optional<'a>(array: &'a StructArray, schema: &Record) -> BoxSerializer /// Creates a [`StreamingIterator`] trait object that presents items from `array` /// encoded according to `schema`. /// # Panic -/// This function panics iff the `data_type` is not supported (use [`can_serialize`] to check) +/// This function panics iff the `dtype` is not supported (use [`can_serialize`] to check) /// # Implementation /// This function performs minimal CPU work: it dynamically dispatches based on the schema /// and arrow type. 
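// `can_serialize` below recurses through nested dtypes, so a list of a supported primitive is
// itself serializable. A small sketch (the `io::avro::write` re-export path is assumed):
use polars_arrow::datatypes::{ArrowDataType, Field};
use polars_arrow::io::avro::write::can_serialize;
use polars_utils::pl_str::PlSmallStr;

fn check_list_dtype() {
    let dt = ArrowDataType::List(Box::new(Field::new(
        PlSmallStr::from_static("item"),
        ArrowDataType::Int32,
        true,
    )));
    assert!(can_serialize(&dt));
}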
pub fn new_serializer<'a>(array: &'a dyn Array, schema: &AvroSchema) -> BoxSerializer<'a> { - let data_type = array.data_type().to_physical_type(); + let dtype = array.dtype().to_physical_type(); - match (data_type, schema) { + match (dtype, schema) { (PhysicalType::Boolean, AvroSchema::Boolean) => { let values = array.as_any().downcast_ref::().unwrap(); Box::new(BufStreamingIterator::new( @@ -497,18 +497,18 @@ pub fn new_serializer<'a>(array: &'a dyn Array, schema: &AvroSchema) -> BoxSeria } } -/// Whether [`new_serializer`] supports `data_type`. -pub fn can_serialize(data_type: &ArrowDataType) -> bool { +/// Whether [`new_serializer`] supports `dtype`. +pub fn can_serialize(dtype: &ArrowDataType) -> bool { use ArrowDataType::*; - match data_type.to_logical_type() { - List(inner) => return can_serialize(&inner.data_type), - LargeList(inner) => return can_serialize(&inner.data_type), - Struct(inner) => return inner.iter().all(|inner| can_serialize(&inner.data_type)), + match dtype.to_logical_type() { + List(inner) => return can_serialize(&inner.dtype), + LargeList(inner) => return can_serialize(&inner.dtype), + Struct(inner) => return inner.iter().all(|inner| can_serialize(&inner.dtype)), _ => {}, }; matches!( - data_type, + dtype, Boolean | Int32 | Int64 diff --git a/crates/polars-arrow/src/io/flight/mod.rs b/crates/polars-arrow/src/io/flight/mod.rs index 85a2be24ea13..c02a4889f7bb 100644 --- a/crates/polars-arrow/src/io/flight/mod.rs +++ b/crates/polars-arrow/src/io/flight/mod.rs @@ -79,7 +79,7 @@ pub fn serialize_schema_to_info( let encoded_data = if let Some(ipc_fields) = ipc_fields { schema_as_encoded_data(schema, ipc_fields) } else { - let ipc_fields = default_ipc_fields(&schema.fields); + let ipc_fields = default_ipc_fields(schema.iter_values()); schema_as_encoded_data(schema, &ipc_fields) }; @@ -92,7 +92,7 @@ fn _serialize_schema(schema: &ArrowSchema, ipc_fields: Option<&[IpcField]>) -> V if let Some(ipc_fields) = ipc_fields { write::schema_to_bytes(schema, ipc_fields) } else { - let ipc_fields = default_ipc_fields(&schema.fields); + let ipc_fields = default_ipc_fields(schema.iter_values()); write::schema_to_bytes(schema, &ipc_fields) } } @@ -113,7 +113,7 @@ pub fn deserialize_schemas(bytes: &[u8]) -> PolarsResult<(ArrowSchema, IpcSchema /// Deserializes [`FlightData`] representing a record batch message to [`RecordBatchT`]. pub fn deserialize_batch( data: &FlightData, - fields: &[Field], + fields: &ArrowSchema, ipc_schema: &IpcSchema, dictionaries: &read::Dictionaries, ) -> PolarsResult>> { @@ -147,7 +147,7 @@ pub fn deserialize_batch( /// Deserializes [`FlightData`], assuming it to be a dictionary message, into `dictionaries`. 
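// Sketch of consuming a Flight stream with the `ArrowSchema`-based signatures in this file:
// the schema pair is assumed to arrive in the first message's `data_header` (the typical
// Flight layout), and `deserialize_message` either yields a record batch or upserts a
// dictionary into `dictionaries`.
use arrow_format::flight::data::FlightData;
use polars_arrow::io::flight::{deserialize_message, deserialize_schemas};
use polars_arrow::io::ipc::read::Dictionaries;
use polars_error::PolarsResult;

fn collect_batches(messages: &[FlightData]) -> PolarsResult<usize> {
    let (schema, ipc_schema) = deserialize_schemas(&messages[0].data_header)?;
    let mut dictionaries = Dictionaries::default();
    let mut n_batches = 0;
    for data in &messages[1..] {
        if deserialize_message(data, &schema, &ipc_schema, &mut dictionaries)?.is_some() {
            n_batches += 1;
        }
    }
    Ok(n_batches)
}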
pub fn deserialize_dictionary( data: &FlightData, - fields: &[Field], + fields: &ArrowSchema, ipc_schema: &IpcSchema, dictionaries: &mut read::Dictionaries, ) -> PolarsResult<()> { @@ -182,7 +182,7 @@ pub fn deserialize_dictionary( /// or by upserting into `dictionaries` (when the message is a dictionary) pub fn deserialize_message( data: &FlightData, - fields: &[Field], + fields: &ArrowSchema, ipc_schema: &IpcSchema, dictionaries: &mut Dictionaries, ) -> PolarsResult>>> { diff --git a/crates/polars-arrow/src/io/ipc/read/array/binary.rs b/crates/polars-arrow/src/io/ipc/read/array/binary.rs index 9553212ec5c4..d46f5ca102ac 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/binary.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/binary.rs @@ -14,7 +14,7 @@ use crate::offset::Offset; #[allow(clippy::too_many_arguments)] pub fn read_binary( field_nodes: &mut VecDeque, - data_type: ArrowDataType, + dtype: ArrowDataType, buffers: &mut VecDeque, reader: &mut R, block_offset: u64, @@ -23,7 +23,7 @@ pub fn read_binary( limit: Option, scratch: &mut Vec, ) -> PolarsResult> { - let field_node = try_get_field_node(field_nodes, &data_type)?; + let field_node = try_get_field_node(field_nodes, &dtype)?; let validity = read_validity( buffers, @@ -61,7 +61,7 @@ pub fn read_binary( scratch, )?; - BinaryArray::::try_new(data_type, offsets.try_into()?, values, validity) + BinaryArray::::try_new(dtype, offsets.try_into()?, values, validity) } pub fn skip_binary( diff --git a/crates/polars-arrow/src/io/ipc/read/array/binview.rs b/crates/polars-arrow/src/io/ipc/read/array/binview.rs index 8d5725023791..4423cdaab6e4 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/binview.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/binview.rs @@ -12,7 +12,7 @@ use crate::buffer::Buffer; pub fn read_binview( field_nodes: &mut VecDeque, variadic_buffer_counts: &mut VecDeque, - data_type: ArrowDataType, + dtype: ArrowDataType, buffers: &mut VecDeque, reader: &mut R, block_offset: u64, @@ -21,7 +21,7 @@ pub fn read_binview( limit: Option, scratch: &mut Vec, ) -> PolarsResult { - let field_node = try_get_field_node(field_nodes, &data_type)?; + let field_node = try_get_field_node(field_nodes, &dtype)?; let validity = read_validity( buffers, @@ -62,7 +62,7 @@ pub fn read_binview( }) .collect::>>>()?; - BinaryViewArrayGeneric::::try_new(data_type, views, Arc::from(variadic_buffers), validity) + BinaryViewArrayGeneric::::try_new(dtype, views, Arc::from(variadic_buffers), validity) .map(|arr| arr.boxed()) } diff --git a/crates/polars-arrow/src/io/ipc/read/array/boolean.rs b/crates/polars-arrow/src/io/ipc/read/array/boolean.rs index 16443b0b8af0..ebc9ed510380 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/boolean.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/boolean.rs @@ -12,7 +12,7 @@ use crate::io::ipc::read::array::{try_get_array_length, try_get_field_node}; #[allow(clippy::too_many_arguments)] pub fn read_boolean( field_nodes: &mut VecDeque, - data_type: ArrowDataType, + dtype: ArrowDataType, buffers: &mut VecDeque, reader: &mut R, block_offset: u64, @@ -21,7 +21,7 @@ pub fn read_boolean( limit: Option, scratch: &mut Vec, ) -> PolarsResult { - let field_node = try_get_field_node(field_nodes, &data_type)?; + let field_node = try_get_field_node(field_nodes, &dtype)?; let validity = read_validity( buffers, @@ -45,7 +45,7 @@ pub fn read_boolean( compression, scratch, )?; - BooleanArray::try_new(data_type, values, validity) + BooleanArray::try_new(dtype, values, validity) } pub fn skip_boolean( diff --git 
a/crates/polars-arrow/src/io/ipc/read/array/dictionary.rs b/crates/polars-arrow/src/io/ipc/read/array/dictionary.rs index 5a43fe21e102..88f9ef46de89 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/dictionary.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/dictionary.rs @@ -1,8 +1,8 @@ use std::collections::VecDeque; use std::io::{Read, Seek}; -use ahash::HashSet; use polars_error::{polars_bail, polars_err, PolarsResult}; +use polars_utils::aliases::PlHashSet; use super::super::{Compression, Dictionaries, IpcBuffer, Node}; use super::{read_primitive, skip_primitive}; @@ -12,7 +12,7 @@ use crate::datatypes::ArrowDataType; #[allow(clippy::too_many_arguments)] pub fn read_dictionary( field_nodes: &mut VecDeque, - data_type: ArrowDataType, + dtype: ArrowDataType, id: Option, buffers: &mut VecDeque, reader: &mut R, @@ -34,7 +34,7 @@ where let values = dictionaries .get(&id) .ok_or_else(|| { - let valid_ids = dictionaries.keys().collect::>(); + let valid_ids = dictionaries.keys().collect::>(); polars_err!(ComputeError: "Dictionary id {id} not found. Valid ids: {valid_ids:?}" ) @@ -53,7 +53,7 @@ where scratch, )?; - DictionaryArray::::try_new(data_type, keys, values) + DictionaryArray::::try_new(dtype, keys, values) } pub fn skip_dictionary( diff --git a/crates/polars-arrow/src/io/ipc/read/array/fixed_size_binary.rs b/crates/polars-arrow/src/io/ipc/read/array/fixed_size_binary.rs index 9683952c6d6c..61a8055528e7 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/fixed_size_binary.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/fixed_size_binary.rs @@ -12,7 +12,7 @@ use crate::io::ipc::read::array::{try_get_array_length, try_get_field_node}; #[allow(clippy::too_many_arguments)] pub fn read_fixed_size_binary( field_nodes: &mut VecDeque, - data_type: ArrowDataType, + dtype: ArrowDataType, buffers: &mut VecDeque, reader: &mut R, block_offset: u64, @@ -21,7 +21,7 @@ pub fn read_fixed_size_binary( limit: Option, scratch: &mut Vec, ) -> PolarsResult { - let field_node = try_get_field_node(field_nodes, &data_type)?; + let field_node = try_get_field_node(field_nodes, &dtype)?; let validity = read_validity( buffers, @@ -36,7 +36,7 @@ pub fn read_fixed_size_binary( let length = try_get_array_length(field_node, limit)?; - let length = length.saturating_mul(FixedSizeBinaryArray::maybe_get_size(&data_type)?); + let length = length.saturating_mul(FixedSizeBinaryArray::maybe_get_size(&dtype)?); let values = read_buffer( buffers, length, @@ -47,7 +47,7 @@ pub fn read_fixed_size_binary( scratch, )?; - FixedSizeBinaryArray::try_new(data_type, values, validity) + FixedSizeBinaryArray::try_new(dtype, values, validity) } pub fn skip_fixed_size_binary( diff --git a/crates/polars-arrow/src/io/ipc/read/array/fixed_size_list.rs b/crates/polars-arrow/src/io/ipc/read/array/fixed_size_list.rs index 1f303a156787..eac68f9fda54 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/fixed_size_list.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/fixed_size_list.rs @@ -15,7 +15,7 @@ use crate::io::ipc::read::array::try_get_field_node; pub fn read_fixed_size_list( field_nodes: &mut VecDeque, variadic_buffer_counts: &mut VecDeque, - data_type: ArrowDataType, + dtype: ArrowDataType, ipc_field: &IpcField, buffers: &mut VecDeque, reader: &mut R, @@ -27,7 +27,7 @@ pub fn read_fixed_size_list( version: Version, scratch: &mut Vec, ) -> PolarsResult { - let field_node = try_get_field_node(field_nodes, &data_type)?; + let field_node = try_get_field_node(field_nodes, &dtype)?; let validity = read_validity( buffers, @@ 
-40,7 +40,7 @@ pub fn read_fixed_size_list( scratch, )?; - let (field, size) = FixedSizeListArray::get_child_and_size(&data_type); + let (field, size) = FixedSizeListArray::get_child_and_size(&dtype); let limit = limit.map(|x| x.saturating_mul(size)); @@ -59,12 +59,12 @@ pub fn read_fixed_size_list( version, scratch, )?; - FixedSizeListArray::try_new(data_type, values, validity) + FixedSizeListArray::try_new(dtype, values, validity) } pub fn skip_fixed_size_list( field_nodes: &mut VecDeque, - data_type: &ArrowDataType, + dtype: &ArrowDataType, buffers: &mut VecDeque, variadic_buffer_counts: &mut VecDeque, ) -> PolarsResult<()> { @@ -78,12 +78,7 @@ pub fn skip_fixed_size_list( .pop_front() .ok_or_else(|| polars_err!(oos = "IPC: missing validity buffer."))?; - let (field, _) = FixedSizeListArray::get_child_and_size(data_type); + let (field, _) = FixedSizeListArray::get_child_and_size(dtype); - skip( - field_nodes, - field.data_type(), - buffers, - variadic_buffer_counts, - ) + skip(field_nodes, field.dtype(), buffers, variadic_buffer_counts) } diff --git a/crates/polars-arrow/src/io/ipc/read/array/list.rs b/crates/polars-arrow/src/io/ipc/read/array/list.rs index 45566fd5df9f..b89a03cf552a 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/list.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/list.rs @@ -17,7 +17,7 @@ use crate::offset::Offset; pub fn read_list( field_nodes: &mut VecDeque, variadic_buffer_counts: &mut VecDeque, - data_type: ArrowDataType, + dtype: ArrowDataType, ipc_field: &IpcField, buffers: &mut VecDeque, reader: &mut R, @@ -32,7 +32,7 @@ pub fn read_list( where Vec: TryInto, { - let field_node = try_get_field_node(field_nodes, &data_type)?; + let field_node = try_get_field_node(field_nodes, &dtype)?; let validity = read_validity( buffers, @@ -61,7 +61,7 @@ where let last_offset = offsets.last().unwrap().to_usize(); - let field = ListArray::::get_child_field(&data_type); + let field = ListArray::::get_child_field(&dtype); let values = read( field_nodes, @@ -78,12 +78,12 @@ where version, scratch, )?; - ListArray::try_new(data_type, offsets.try_into()?, values, validity) + ListArray::try_new(dtype, offsets.try_into()?, values, validity) } pub fn skip_list( field_nodes: &mut VecDeque, - data_type: &ArrowDataType, + dtype: &ArrowDataType, buffers: &mut VecDeque, variadic_buffer_counts: &mut VecDeque, ) -> PolarsResult<()> { @@ -100,7 +100,7 @@ pub fn skip_list( .pop_front() .ok_or_else(|| polars_err!(oos = "IPC: missing offsets buffer."))?; - let data_type = ListArray::::get_child_type(data_type); + let dtype = ListArray::::get_child_type(dtype); - skip(field_nodes, data_type, buffers, variadic_buffer_counts) + skip(field_nodes, dtype, buffers, variadic_buffer_counts) } diff --git a/crates/polars-arrow/src/io/ipc/read/array/map.rs b/crates/polars-arrow/src/io/ipc/read/array/map.rs index 741d496a5a63..17e963f5dfa4 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/map.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/map.rs @@ -16,7 +16,7 @@ use crate::io::ipc::read::array::{try_get_array_length, try_get_field_node}; pub fn read_map( field_nodes: &mut VecDeque, variadic_buffer_counts: &mut VecDeque, - data_type: ArrowDataType, + dtype: ArrowDataType, ipc_field: &IpcField, buffers: &mut VecDeque, reader: &mut R, @@ -28,7 +28,7 @@ pub fn read_map( version: Version, scratch: &mut Vec, ) -> PolarsResult { - let field_node = try_get_field_node(field_nodes, &data_type)?; + let field_node = try_get_field_node(field_nodes, &dtype)?; let validity = read_validity( buffers, @@ 
-55,7 +55,7 @@ pub fn read_map( // Older versions of the IPC format sometimes do not report an offset .or_else(|_| PolarsResult::Ok(Buffer::::from(vec![0i32])))?; - let field = MapArray::get_field(&data_type); + let field = MapArray::get_field(&dtype); let last_offset: usize = offsets.last().copied().unwrap() as usize; @@ -74,12 +74,12 @@ pub fn read_map( version, scratch, )?; - MapArray::try_new(data_type, offsets.try_into()?, field, validity) + MapArray::try_new(dtype, offsets.try_into()?, field, validity) } pub fn skip_map( field_nodes: &mut VecDeque, - data_type: &ArrowDataType, + dtype: &ArrowDataType, buffers: &mut VecDeque, variadic_buffer_counts: &mut VecDeque, ) -> PolarsResult<()> { @@ -96,7 +96,7 @@ pub fn skip_map( .pop_front() .ok_or_else(|| polars_err!(oos = "IPC: missing offsets buffer."))?; - let data_type = MapArray::get_field(data_type).data_type(); + let dtype = MapArray::get_field(dtype).dtype(); - skip(field_nodes, data_type, buffers, variadic_buffer_counts) + skip(field_nodes, dtype, buffers, variadic_buffer_counts) } diff --git a/crates/polars-arrow/src/io/ipc/read/array/mod.rs b/crates/polars-arrow/src/io/ipc/read/array/mod.rs index 2ffe1a369c25..21c393a2869e 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/mod.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/mod.rs @@ -34,10 +34,10 @@ use crate::datatypes::ArrowDataType; fn try_get_field_node<'a>( field_nodes: &mut VecDeque>, - data_type: &ArrowDataType, + dtype: &ArrowDataType, ) -> PolarsResult> { field_nodes.pop_front().ok_or_else(|| { - polars_err!(ComputeError: "IPC: unable to fetch the field for {:?}\n\nThe file or stream is corrupted.", data_type) + polars_err!(ComputeError: "IPC: unable to fetch the field for {:?}\n\nThe file or stream is corrupted.", dtype) }) } diff --git a/crates/polars-arrow/src/io/ipc/read/array/null.rs b/crates/polars-arrow/src/io/ipc/read/array/null.rs index f9df4d254900..6fac4ae2d7bb 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/null.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/null.rs @@ -9,14 +9,14 @@ use crate::io::ipc::read::array::{try_get_array_length, try_get_field_node}; pub fn read_null( field_nodes: &mut VecDeque, - data_type: ArrowDataType, + dtype: ArrowDataType, limit: Option, ) -> PolarsResult { - let field_node = try_get_field_node(field_nodes, &data_type)?; + let field_node = try_get_field_node(field_nodes, &dtype)?; let length = try_get_array_length(field_node, limit)?; - NullArray::try_new(data_type, length) + NullArray::try_new(dtype, length) } pub fn skip_null(field_nodes: &mut VecDeque) -> PolarsResult<()> { diff --git a/crates/polars-arrow/src/io/ipc/read/array/primitive.rs b/crates/polars-arrow/src/io/ipc/read/array/primitive.rs index 04304aadca90..a530cba97bb1 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/primitive.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/primitive.rs @@ -13,7 +13,7 @@ use crate::types::NativeType; #[allow(clippy::too_many_arguments)] pub fn read_primitive( field_nodes: &mut VecDeque, - data_type: ArrowDataType, + dtype: ArrowDataType, buffers: &mut VecDeque, reader: &mut R, block_offset: u64, @@ -25,7 +25,7 @@ pub fn read_primitive( where Vec: TryInto, { - let field_node = try_get_field_node(field_nodes, &data_type)?; + let field_node = try_get_field_node(field_nodes, &dtype)?; let validity = read_validity( buffers, @@ -49,7 +49,7 @@ where compression, scratch, )?; - PrimitiveArray::::try_new(data_type, values, validity) + PrimitiveArray::::try_new(dtype, values, validity) } pub fn skip_primitive( 
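
The hunks above and below all apply the same migration: IPC/Avro readers and writers now take an `&ArrowSchema` (an ordered name-to-`Field` map iterated with `iter_values()`) instead of a `&[Field]` slice, field names and metadata use `PlSmallStr` rather than `String`, and `data_type`/`data_type()` is renamed to `dtype`/`dtype()`. The following minimal sketch (not part of the patch) shows how those pieces fit together; it only uses calls that appear in this diff (`ArrowSchema::with_capacity`, `insert`, `iter_values`, `Field::new` with a `PlSmallStr` name, `field.dtype()`), and the helper name `build_example_schema` is hypothetical.

// Hedged sketch of the migrated API surface, assuming ArrowSchema behaves like the
// ordered name -> Field map used elsewhere in this diff. Imports mirror the ones the
// patch adds inside polars-arrow.
use polars_utils::pl_str::PlSmallStr;
use crate::datatypes::{ArrowDataType, ArrowSchema, Field};

fn build_example_schema() -> ArrowSchema {
    let mut schema = ArrowSchema::with_capacity(2);

    // Field names are now PlSmallStr instead of String.
    let a = Field::new(PlSmallStr::from_str("a"), ArrowDataType::Int64, true);
    let b = Field::new(PlSmallStr::from_str("b"), ArrowDataType::Utf8View, true);

    // The schema is keyed by field name, as in fb_to_schema below.
    schema.insert(a.name.clone(), a);
    schema.insert(b.name.clone(), b);

    // Call sites that previously iterated `&schema.fields` now use `iter_values()`,
    // and `field.data_type()` has been renamed to `field.dtype()`.
    for field in schema.iter_values() {
        let _dtype: &ArrowDataType = field.dtype();
    }

    schema
}
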
diff --git a/crates/polars-arrow/src/io/ipc/read/array/struct_.rs b/crates/polars-arrow/src/io/ipc/read/array/struct_.rs index 6dc716ab368b..5cf68f1d1d95 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/struct_.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/struct_.rs @@ -15,7 +15,7 @@ use crate::io::ipc::read::array::try_get_field_node; pub fn read_struct( field_nodes: &mut VecDeque, variadic_buffer_counts: &mut VecDeque, - data_type: ArrowDataType, + dtype: ArrowDataType, ipc_field: &IpcField, buffers: &mut VecDeque, reader: &mut R, @@ -27,7 +27,7 @@ pub fn read_struct( version: Version, scratch: &mut Vec, ) -> PolarsResult { - let field_node = try_get_field_node(field_nodes, &data_type)?; + let field_node = try_get_field_node(field_nodes, &dtype)?; let validity = read_validity( buffers, @@ -40,7 +40,7 @@ pub fn read_struct( scratch, )?; - let fields = StructArray::get_fields(&data_type); + let fields = StructArray::get_fields(&dtype); let values = fields .iter() @@ -64,12 +64,12 @@ pub fn read_struct( }) .collect::>>()?; - StructArray::try_new(data_type, values, validity) + StructArray::try_new(dtype, values, validity) } pub fn skip_struct( field_nodes: &mut VecDeque, - data_type: &ArrowDataType, + dtype: &ArrowDataType, buffers: &mut VecDeque, variadic_buffer_counts: &mut VecDeque, ) -> PolarsResult<()> { @@ -83,14 +83,9 @@ pub fn skip_struct( .pop_front() .ok_or_else(|| polars_err!(oos = "IPC: missing validity buffer."))?; - let fields = StructArray::get_fields(data_type); + let fields = StructArray::get_fields(dtype); - fields.iter().try_for_each(|field| { - skip( - field_nodes, - field.data_type(), - buffers, - variadic_buffer_counts, - ) - }) + fields + .iter() + .try_for_each(|field| skip(field_nodes, field.dtype(), buffers, variadic_buffer_counts)) } diff --git a/crates/polars-arrow/src/io/ipc/read/array/union.rs b/crates/polars-arrow/src/io/ipc/read/array/union.rs index 192d9582ed21..b84ff3349aed 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/union.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/union.rs @@ -16,7 +16,7 @@ use crate::io::ipc::read::array::{try_get_array_length, try_get_field_node}; pub fn read_union( field_nodes: &mut VecDeque, variadic_buffer_counts: &mut VecDeque, - data_type: ArrowDataType, + dtype: ArrowDataType, ipc_field: &IpcField, buffers: &mut VecDeque, reader: &mut R, @@ -28,7 +28,7 @@ pub fn read_union( version: Version, scratch: &mut Vec, ) -> PolarsResult { - let field_node = try_get_field_node(field_nodes, &data_type)?; + let field_node = try_get_field_node(field_nodes, &dtype)?; if version != Version::V5 { let _ = buffers @@ -48,7 +48,7 @@ pub fn read_union( scratch, )?; - let offsets = if let ArrowDataType::Union(_, _, mode) = data_type { + let offsets = if let ArrowDataType::Union(_, _, mode) = dtype { if !mode.is_sparse() { Some(read_buffer( buffers, @@ -66,7 +66,7 @@ pub fn read_union( unreachable!() }; - let fields = UnionArray::get_fields(&data_type); + let fields = UnionArray::get_fields(&dtype); let fields = fields .iter() @@ -90,12 +90,12 @@ pub fn read_union( }) .collect::>>()?; - UnionArray::try_new(data_type, types, fields, offsets) + UnionArray::try_new(dtype, types, fields, offsets) } pub fn skip_union( field_nodes: &mut VecDeque, - data_type: &ArrowDataType, + dtype: &ArrowDataType, buffers: &mut VecDeque, variadic_buffer_counts: &mut VecDeque, ) -> PolarsResult<()> { @@ -108,7 +108,7 @@ pub fn skip_union( let _ = buffers .pop_front() .ok_or_else(|| polars_err!(oos = "IPC: missing validity buffer."))?; - if let 
ArrowDataType::Union(_, _, Dense) = data_type { + if let ArrowDataType::Union(_, _, Dense) = dtype { let _ = buffers .pop_front() .ok_or_else(|| polars_err!(oos = "IPC: missing offsets buffer."))?; @@ -116,14 +116,9 @@ pub fn skip_union( unreachable!() }; - let fields = UnionArray::get_fields(data_type); + let fields = UnionArray::get_fields(dtype); - fields.iter().try_for_each(|field| { - skip( - field_nodes, - field.data_type(), - buffers, - variadic_buffer_counts, - ) - }) + fields + .iter() + .try_for_each(|field| skip(field_nodes, field.dtype(), buffers, variadic_buffer_counts)) } diff --git a/crates/polars-arrow/src/io/ipc/read/array/utf8.rs b/crates/polars-arrow/src/io/ipc/read/array/utf8.rs index f29f8d8cdb26..33f43baf7f1b 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/utf8.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/utf8.rs @@ -11,7 +11,7 @@ use crate::offset::Offset; #[allow(clippy::too_many_arguments)] pub fn read_utf8( field_nodes: &mut VecDeque, - data_type: ArrowDataType, + dtype: ArrowDataType, buffers: &mut VecDeque, reader: &mut R, block_offset: u64, @@ -20,7 +20,7 @@ pub fn read_utf8( limit: Option, scratch: &mut Vec, ) -> PolarsResult> { - let field_node = try_get_field_node(field_nodes, &data_type)?; + let field_node = try_get_field_node(field_nodes, &dtype)?; let validity = read_validity( buffers, @@ -58,7 +58,7 @@ pub fn read_utf8( scratch, )?; - Utf8Array::::try_new(data_type, offsets.try_into()?, values, validity) + Utf8Array::::try_new(dtype, offsets.try_into()?, values, validity) } pub fn skip_utf8( diff --git a/crates/polars-arrow/src/io/ipc/read/common.rs b/crates/polars-arrow/src/io/ipc/read/common.rs index f2ab79b56ad9..2458cba702b9 100644 --- a/crates/polars-arrow/src/io/ipc/read/common.rs +++ b/crates/polars-arrow/src/io/ipc/read/common.rs @@ -1,13 +1,14 @@ use std::collections::VecDeque; use std::io::{Read, Seek}; -use ahash::AHashMap; use polars_error::{polars_bail, polars_err, PolarsResult}; +use polars_utils::aliases::PlHashMap; +use polars_utils::pl_str::PlSmallStr; use super::deserialize::{read, skip}; use super::Dictionaries; use crate::array::*; -use crate::datatypes::{ArrowDataType, Field}; +use crate::datatypes::{ArrowDataType, ArrowSchema, Field}; use crate::io::ipc::read::OutOfSpecKind; use crate::io::ipc::{IpcField, IpcSchema}; use crate::record_batch::RecordBatchT; @@ -76,7 +77,7 @@ impl<'a, A, I: Iterator> Iterator for ProjectionIter<'a, A, I> { #[allow(clippy::too_many_arguments)] pub fn read_record_batch( batch: arrow_format::ipc::RecordBatchRef, - fields: &[Field], + fields: &ArrowSchema, ipc_schema: &IpcSchema, projection: Option<&[usize]>, limit: Option, @@ -126,8 +127,10 @@ pub fn read_record_batch( let mut field_nodes = field_nodes.iter().collect::>(); let columns = if let Some(projection) = projection { - let projection = - ProjectionIter::new(projection, fields.iter().zip(ipc_schema.fields.iter())); + let projection = ProjectionIter::new( + projection, + fields.iter_values().zip(ipc_schema.fields.iter()), + ); projection .map(|maybe_field| match maybe_field { @@ -151,7 +154,7 @@ pub fn read_record_batch( ProjectionResult::NotSelected((field, _)) => { skip( &mut field_nodes, - &field.data_type, + &field.dtype, &mut buffers, &mut variadic_buffer_counts, )?; @@ -162,7 +165,7 @@ pub fn read_record_batch( .collect::>>()? 
} else { fields - .iter() + .iter_values() .zip(ipc_schema.fields.iter()) .map(|(field, ipc_field)| { read( @@ -190,11 +193,11 @@ pub fn read_record_batch( fn find_first_dict_field_d<'a>( id: i64, - data_type: &'a ArrowDataType, + dtype: &'a ArrowDataType, ipc_field: &'a IpcField, ) -> Option<(&'a Field, &'a IpcField)> { use ArrowDataType::*; - match data_type { + match dtype { Dictionary(_, inner, _) => find_first_dict_field_d(id, inner.as_ref(), ipc_field), List(field) | LargeList(field) | FixedSizeList(field, ..) | Map(field, ..) => { find_first_dict_field(id, field.as_ref(), &ipc_field.fields[0]) @@ -221,16 +224,16 @@ fn find_first_dict_field<'a>( return Some((field, ipc_field)); } } - find_first_dict_field_d(id, &field.data_type, ipc_field) + find_first_dict_field_d(id, &field.dtype, ipc_field) } pub(crate) fn first_dict_field<'a>( id: i64, - fields: &'a [Field], + fields: &'a ArrowSchema, ipc_fields: &'a [IpcField], ) -> PolarsResult<(&'a Field, &'a IpcField)> { assert_eq!(fields.len(), ipc_fields.len()); - for (field, ipc_field) in fields.iter().zip(ipc_fields.iter()) { + for (field, ipc_field) in fields.iter_values().zip(ipc_fields.iter()) { if let Some(field) = find_first_dict_field(id, field, ipc_field) { return Ok(field); } @@ -245,7 +248,7 @@ pub(crate) fn first_dict_field<'a>( #[allow(clippy::too_many_arguments)] pub fn read_dictionary( batch: arrow_format::ipc::DictionaryBatchRef, - fields: &[Field], + fields: &ArrowSchema, ipc_schema: &IpcSchema, dictionaries: &mut Dictionaries, reader: &mut R, @@ -270,16 +273,19 @@ pub fn read_dictionary( .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferData(err)))? .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingData))?; - let value_type = if let ArrowDataType::Dictionary(_, value_type, _) = - first_field.data_type.to_logical_type() - { - value_type.as_ref() - } else { - polars_bail!(oos = OutOfSpecKind::InvalidIdDataType { requested_id: id }) - }; + let value_type = + if let ArrowDataType::Dictionary(_, value_type, _) = first_field.dtype.to_logical_type() { + value_type.as_ref() + } else { + polars_bail!(oos = OutOfSpecKind::InvalidIdDataType { requested_id: id }) + }; // Make a fake schema for the dictionary batch. 
- let fields = vec![Field::new("", value_type.clone(), false)]; + let fields = std::iter::once(( + PlSmallStr::EMPTY, + Field::new(PlSmallStr::EMPTY, value_type.clone(), false), + )) + .collect(); let ipc_schema = IpcSchema { fields: vec![first_ipc_field.clone()], is_little_endian: ipc_schema.is_little_endian, @@ -304,16 +310,22 @@ pub fn read_dictionary( } pub fn prepare_projection( - fields: &[Field], + schema: &ArrowSchema, mut projection: Vec, -) -> (Vec, AHashMap, Vec) { - let fields = projection.iter().map(|x| fields[*x].clone()).collect(); +) -> (Vec, PlHashMap, ArrowSchema) { + let schema = projection + .iter() + .map(|x| { + let (k, v) = schema.get_at_index(*x).unwrap(); + (k.clone(), v.clone()) + }) + .collect(); // todo: find way to do this more efficiently let mut indices = (0..projection.len()).collect::>(); indices.sort_unstable_by_key(|&i| &projection[i]); let map = indices.iter().copied().enumerate().fold( - AHashMap::default(), + PlHashMap::default(), |mut acc, (index, new_index)| { acc.insert(index, new_index); acc @@ -334,12 +346,12 @@ pub fn prepare_projection( } } - (projection, map, fields) + (projection, map, schema) } pub fn apply_projection( chunk: RecordBatchT>, - map: &AHashMap, + map: &PlHashMap, ) -> RecordBatchT> { // re-order according to projection let arrays = chunk.into_arrays(); diff --git a/crates/polars-arrow/src/io/ipc/read/deserialize.rs b/crates/polars-arrow/src/io/ipc/read/deserialize.rs index f27d9b58100e..1a57ac487c70 100644 --- a/crates/polars-arrow/src/io/ipc/read/deserialize.rs +++ b/crates/polars-arrow/src/io/ipc/read/deserialize.rs @@ -28,13 +28,13 @@ pub fn read( scratch: &mut Vec, ) -> PolarsResult> { use PhysicalType::*; - let data_type = field.data_type.clone(); + let dtype = field.dtype.clone(); - match data_type.to_physical_type() { - Null => read_null(field_nodes, data_type, limit).map(|x| x.boxed()), + match dtype.to_physical_type() { + Null => read_null(field_nodes, dtype, limit).map(|x| x.boxed()), Boolean => read_boolean( field_nodes, - data_type, + dtype, buffers, reader, block_offset, @@ -47,7 +47,7 @@ pub fn read( Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { read_primitive::<$T, _>( field_nodes, - data_type, + dtype, buffers, reader, block_offset, @@ -60,7 +60,7 @@ pub fn read( }), Binary => read_binary::( field_nodes, - data_type, + dtype, buffers, reader, block_offset, @@ -72,7 +72,7 @@ pub fn read( .map(|x| x.boxed()), LargeBinary => read_binary::( field_nodes, - data_type, + dtype, buffers, reader, block_offset, @@ -84,7 +84,7 @@ pub fn read( .map(|x| x.boxed()), FixedSizeBinary => read_fixed_size_binary( field_nodes, - data_type, + dtype, buffers, reader, block_offset, @@ -96,7 +96,7 @@ pub fn read( .map(|x| x.boxed()), Utf8 => read_utf8::( field_nodes, - data_type, + dtype, buffers, reader, block_offset, @@ -108,7 +108,7 @@ pub fn read( .map(|x| x.boxed()), LargeUtf8 => read_utf8::( field_nodes, - data_type, + dtype, buffers, reader, block_offset, @@ -121,7 +121,7 @@ pub fn read( List => read_list::( field_nodes, variadic_buffer_counts, - data_type, + dtype, ipc_field, buffers, reader, @@ -137,7 +137,7 @@ pub fn read( LargeList => read_list::( field_nodes, variadic_buffer_counts, - data_type, + dtype, ipc_field, buffers, reader, @@ -153,7 +153,7 @@ pub fn read( FixedSizeList => read_fixed_size_list( field_nodes, variadic_buffer_counts, - data_type, + dtype, ipc_field, buffers, reader, @@ -169,7 +169,7 @@ pub fn read( Struct => read_struct( field_nodes, variadic_buffer_counts, - data_type, + 
dtype, ipc_field, buffers, reader, @@ -186,7 +186,7 @@ pub fn read( match_integer_type!(key_type, |$T| { read_dictionary::<$T, _>( field_nodes, - data_type, + dtype, ipc_field.dictionary_id, buffers, reader, @@ -203,7 +203,7 @@ pub fn read( Union => read_union( field_nodes, variadic_buffer_counts, - data_type, + dtype, ipc_field, buffers, reader, @@ -219,7 +219,7 @@ pub fn read( Map => read_map( field_nodes, variadic_buffer_counts, - data_type, + dtype, ipc_field, buffers, reader, @@ -235,7 +235,7 @@ pub fn read( Utf8View => read_binview::( field_nodes, variadic_buffer_counts, - data_type, + dtype, buffers, reader, block_offset, @@ -247,7 +247,7 @@ pub fn read( BinaryView => read_binview::<[u8], _>( field_nodes, variadic_buffer_counts, - data_type, + dtype, buffers, reader, block_offset, @@ -261,27 +261,25 @@ pub fn read( pub fn skip( field_nodes: &mut VecDeque, - data_type: &ArrowDataType, + dtype: &ArrowDataType, buffers: &mut VecDeque, variadic_buffer_counts: &mut VecDeque, ) -> PolarsResult<()> { use PhysicalType::*; - match data_type.to_physical_type() { + match dtype.to_physical_type() { Null => skip_null(field_nodes), Boolean => skip_boolean(field_nodes, buffers), Primitive(_) => skip_primitive(field_nodes, buffers), LargeBinary | Binary => skip_binary(field_nodes, buffers), LargeUtf8 | Utf8 => skip_utf8(field_nodes, buffers), FixedSizeBinary => skip_fixed_size_binary(field_nodes, buffers), - List => skip_list::(field_nodes, data_type, buffers, variadic_buffer_counts), - LargeList => skip_list::(field_nodes, data_type, buffers, variadic_buffer_counts), - FixedSizeList => { - skip_fixed_size_list(field_nodes, data_type, buffers, variadic_buffer_counts) - }, - Struct => skip_struct(field_nodes, data_type, buffers, variadic_buffer_counts), + List => skip_list::(field_nodes, dtype, buffers, variadic_buffer_counts), + LargeList => skip_list::(field_nodes, dtype, buffers, variadic_buffer_counts), + FixedSizeList => skip_fixed_size_list(field_nodes, dtype, buffers, variadic_buffer_counts), + Struct => skip_struct(field_nodes, dtype, buffers, variadic_buffer_counts), Dictionary(_) => skip_dictionary(field_nodes, buffers), - Union => skip_union(field_nodes, data_type, buffers, variadic_buffer_counts), - Map => skip_map(field_nodes, data_type, buffers, variadic_buffer_counts), + Union => skip_union(field_nodes, dtype, buffers, variadic_buffer_counts), + Map => skip_map(field_nodes, dtype, buffers, variadic_buffer_counts), BinaryView | Utf8View => skip_binview(field_nodes, buffers, variadic_buffer_counts), } } diff --git a/crates/polars-arrow/src/io/ipc/read/file.rs b/crates/polars-arrow/src/io/ipc/read/file.rs index c873060969d1..6c831064d5a1 100644 --- a/crates/polars-arrow/src/io/ipc/read/file.rs +++ b/crates/polars-arrow/src/io/ipc/read/file.rs @@ -89,7 +89,7 @@ fn read_dictionary_block( read_dictionary( batch, - &metadata.schema.fields, + &metadata.schema, &metadata.ipc_schema, dictionaries, reader, @@ -317,7 +317,7 @@ pub fn read_batch( read_record_batch( batch, - &metadata.schema.fields, + &metadata.schema, &metadata.ipc_schema, projection, limit, diff --git a/crates/polars-arrow/src/io/ipc/read/file_async.rs b/crates/polars-arrow/src/io/ipc/read/file_async.rs index 31f4ceef7e9b..567a58c1a1fb 100644 --- a/crates/polars-arrow/src/io/ipc/read/file_async.rs +++ b/crates/polars-arrow/src/io/ipc/read/file_async.rs @@ -1,18 +1,18 @@ //! 
Async reader for Arrow IPC files use std::io::SeekFrom; -use ahash::AHashMap; use arrow_format::ipc::planus::ReadAsRoot; use arrow_format::ipc::{Block, MessageHeaderRef}; use futures::stream::BoxStream; use futures::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, Stream, StreamExt}; use polars_error::{polars_bail, polars_err, PolarsResult}; +use polars_utils::aliases::PlHashMap; use super::common::{apply_projection, prepare_projection, read_dictionary, read_record_batch}; use super::file::{deserialize_footer, get_record_batch}; use super::{Dictionaries, FileMetadata, OutOfSpecKind}; use crate::array::*; -use crate::datatypes::{ArrowSchema, Field}; +use crate::datatypes::ArrowSchema; use crate::io::ipc::{IpcSchema, ARROW_MAGIC_V2, CONTINUATION_MARKER}; use crate::record_batch::RecordBatchT; @@ -38,11 +38,7 @@ impl<'a> FileStream<'a> { R: AsyncRead + AsyncSeek + Unpin + Send + 'a, { let (projection, schema) = if let Some(projection) = projection { - let (p, h, fields) = prepare_projection(&metadata.schema.fields, projection); - let schema = ArrowSchema { - fields, - metadata: metadata.schema.metadata.clone(), - }; + let (p, h, schema) = prepare_projection(&metadata.schema, projection); (Some((p, h)), Some(schema)) } else { (None, None) @@ -70,7 +66,7 @@ impl<'a> FileStream<'a> { mut reader: R, mut dictionaries: Option, metadata: FileMetadata, - projection: Option<(Vec, AHashMap)>, + projection: Option<(Vec, PlHashMap)>, limit: Option, ) -> BoxStream<'a, PolarsResult>>> where @@ -221,7 +217,7 @@ where read_record_batch( batch, - &metadata.schema.fields, + &metadata.schema, &metadata.ipc_schema, projection, limit, @@ -238,7 +234,7 @@ where async fn read_dictionaries( mut reader: R, - fields: &[Field], + fields: &ArrowSchema, ipc_schema: &IpcSchema, blocks: &[Block], scratch: &mut Vec, @@ -334,14 +330,15 @@ async fn cached_read_dictionaries( ) -> PolarsResult<()> { match (&dictionaries, metadata.dictionaries.as_deref()) { (None, Some(blocks)) => { - let new_dictionaries = read_dictionaries( - reader, - &metadata.schema.fields, - &metadata.ipc_schema, - blocks, - &mut Default::default(), - ) - .await?; + let new_dictionaries: hashbrown::HashMap, ahash::RandomState> = + read_dictionaries( + reader, + &metadata.schema, + &metadata.ipc_schema, + blocks, + &mut Default::default(), + ) + .await?; *dictionaries = Some(new_dictionaries); }, (None, None) => { diff --git a/crates/polars-arrow/src/io/ipc/read/reader.rs b/crates/polars-arrow/src/io/ipc/read/reader.rs index 7befc8dfcc23..8369d2960233 100644 --- a/crates/polars-arrow/src/io/ipc/read/reader.rs +++ b/crates/polars-arrow/src/io/ipc/read/reader.rs @@ -1,7 +1,7 @@ use std::io::{Read, Seek}; -use ahash::AHashMap; use polars_error::PolarsResult; +use polars_utils::aliases::PlHashMap; use super::common::*; use super::{read_batch, read_file_dictionaries, Dictionaries, FileMetadata}; @@ -16,7 +16,7 @@ pub struct FileReader { // the dictionaries are going to be read dictionaries: Option, current_block: usize, - projection: Option<(Vec, AHashMap, ArrowSchema)>, + projection: Option<(Vec, PlHashMap, ArrowSchema)>, remaining: usize, data_scratch: Vec, message_scratch: Vec, @@ -33,11 +33,7 @@ impl FileReader { limit: Option, ) -> Self { let projection = projection.map(|projection| { - let (p, h, fields) = prepare_projection(&metadata.schema.fields, projection); - let schema = ArrowSchema { - fields, - metadata: metadata.schema.metadata.clone(), - }; + let (p, h, schema) = prepare_projection(&metadata.schema, projection); (p, h, schema) }); Self { diff 
--git a/crates/polars-arrow/src/io/ipc/read/schema.rs b/crates/polars-arrow/src/io/ipc/read/schema.rs index a6c1743e6a0b..7fe6141e9b14 100644 --- a/crates/polars-arrow/src/io/ipc/read/schema.rs +++ b/crates/polars-arrow/src/io/ipc/read/schema.rs @@ -1,6 +1,7 @@ use arrow_format::ipc::planus::ReadAsRoot; use arrow_format::ipc::{FieldRef, FixedSizeListRef, MapRef, TimeRef, TimestampRef, UnionRef}; use polars_error::{polars_bail, polars_err, PolarsResult}; +use polars_utils::pl_str::PlSmallStr; use super::super::{IpcField, IpcSchema}; use super::{OutOfSpecKind, StreamMetadata}; @@ -28,14 +29,15 @@ fn deserialize_field(ipc_field: arrow_format::ipc::FieldRef) -> PolarsResult<(Fi let extension = get_extension(&metadata); - let (data_type, ipc_field_) = get_data_type(ipc_field, extension, true)?; + let (dtype, ipc_field_) = get_dtype(ipc_field, extension, true)?; let field = Field { - name: ipc_field - .name()? - .ok_or_else(|| polars_err!(oos = "Every field in IPC must have a name"))? - .to_string(), - data_type, + name: PlSmallStr::from_str( + ipc_field + .name()? + .ok_or_else(|| polars_err!(oos = "Every field in IPC must have a name"))?, + ), + dtype, is_nullable: ipc_field.nullable()?, metadata, }; @@ -49,7 +51,7 @@ fn read_metadata(field: &arrow_format::ipc::FieldRef) -> PolarsResult for kv in list { let kv = kv?; if let (Some(k), Some(v)) = (kv.key()?, kv.value()?) { - metadata_map.insert(k.to_string(), v.to_string()); + metadata_map.insert(PlSmallStr::from_str(k), PlSmallStr::from_str(v)); } } metadata_map @@ -85,7 +87,7 @@ fn deserialize_timeunit(time_unit: arrow_format::ipc::TimeUnit) -> PolarsResult< fn deserialize_time(time: TimeRef) -> PolarsResult<(ArrowDataType, IpcField)> { let unit = deserialize_timeunit(time.unit()?)?; - let data_type = match (time.bit_width()?, unit) { + let dtype = match (time.bit_width()?, unit) { (32, TimeUnit::Second) => ArrowDataType::Time32(TimeUnit::Second), (32, TimeUnit::Millisecond) => ArrowDataType::Time32(TimeUnit::Millisecond), (64, TimeUnit::Microsecond) => ArrowDataType::Time64(TimeUnit::Microsecond), @@ -96,14 +98,14 @@ fn deserialize_time(time: TimeRef) -> PolarsResult<(ArrowDataType, IpcField)> { ) }, }; - Ok((data_type, IpcField::default())) + Ok((dtype, IpcField::default())) } fn deserialize_timestamp(timestamp: TimestampRef) -> PolarsResult<(ArrowDataType, IpcField)> { - let timezone = timestamp.timezone()?.map(|tz| tz.to_string()); + let timezone = timestamp.timezone()?; let time_unit = deserialize_timeunit(timestamp.unit()?)?; Ok(( - ArrowDataType::Timestamp(time_unit, timezone), + ArrowDataType::Timestamp(time_unit, timezone.map(PlSmallStr::from_str)), IpcField::default(), )) } @@ -141,9 +143,9 @@ fn deserialize_map(map: MapRef, field: FieldRef) -> PolarsResult<(ArrowDataType, .ok_or_else(|| polars_err!(oos = "IPC: Map must contain one child"))??; let (field, ipc_field) = deserialize_field(inner)?; - let data_type = ArrowDataType::Map(Box::new(field), is_sorted); + let dtype = ArrowDataType::Map(Box::new(field), is_sorted); Ok(( - data_type, + dtype, IpcField { fields: vec![ipc_field], dictionary_id: None, @@ -232,7 +234,7 @@ fn deserialize_fixed_size_list( } /// Get the Arrow data type from the flatbuffer Field table -fn get_data_type( +fn get_dtype( field: arrow_format::ipc::FieldRef, extension: Extension, may_be_dictionary: bool, @@ -243,7 +245,7 @@ fn get_data_type( .index_type()? 
.ok_or_else(|| polars_err!(oos = "indexType is mandatory in Dictionary."))?; let index_type = deserialize_integer(int)?; - let (inner, mut ipc_field) = get_data_type(field, extension, false)?; + let (inner, mut ipc_field) = get_dtype(field, extension, false)?; ipc_field.dictionary_id = Some(dictionary.id()?); return Ok(( ArrowDataType::Dictionary(index_type, Box::new(inner), dictionary.is_ordered()?), @@ -254,9 +256,9 @@ fn get_data_type( if let Some(extension) = extension { let (name, metadata) = extension; - let (data_type, fields) = get_data_type(field, None, false)?; + let (dtype, fields) = get_dtype(field, None, false)?; return Ok(( - ArrowDataType::Extension(name, Box::new(data_type), metadata), + ArrowDataType::Extension(name, Box::new(dtype), metadata), fields, )); } @@ -270,8 +272,8 @@ fn get_data_type( Null(_) => (ArrowDataType::Null, IpcField::default()), Bool(_) => (ArrowDataType::Boolean, IpcField::default()), Int(int) => { - let data_type = deserialize_integer(int)?.into(); - (data_type, IpcField::default()) + let dtype = deserialize_integer(int)?.into(); + (dtype, IpcField::default()) }, Binary(_) => (ArrowDataType::Binary, IpcField::default()), LargeBinary(_) => (ArrowDataType::LargeBinary, IpcField::default()), @@ -289,24 +291,24 @@ fn get_data_type( IpcField::default(), ), FloatingPoint(float) => { - let data_type = match float.precision()? { + let dtype = match float.precision()? { arrow_format::ipc::Precision::Half => ArrowDataType::Float16, arrow_format::ipc::Precision::Single => ArrowDataType::Float32, arrow_format::ipc::Precision::Double => ArrowDataType::Float64, }; - (data_type, IpcField::default()) + (dtype, IpcField::default()) }, Date(date) => { - let data_type = match date.unit()? { + let dtype = match date.unit()? { arrow_format::ipc::DateUnit::Day => ArrowDataType::Date32, arrow_format::ipc::DateUnit::Millisecond => ArrowDataType::Date64, }; - (data_type, IpcField::default()) + (dtype, IpcField::default()) }, Time(time) => deserialize_time(time)?, Timestamp(timestamp) => deserialize_timestamp(timestamp)?, Interval(interval) => { - let data_type = match interval.unit()? { + let dtype = match interval.unit()? { arrow_format::ipc::IntervalUnit::YearMonth => { ArrowDataType::Interval(IntervalUnit::YearMonth) }, @@ -317,7 +319,7 @@ fn get_data_type( ArrowDataType::Interval(IntervalUnit::MonthDayNano) }, }; - (data_type, IpcField::default()) + (dtype, IpcField::default()) }, Duration(duration) => { let time_unit = deserialize_timeunit(duration.unit()?)?; @@ -337,13 +339,13 @@ fn get_data_type( .try_into() .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; - let data_type = match bit_width { + let dtype = match bit_width { 128 => ArrowDataType::Decimal(precision, scale), 256 => ArrowDataType::Decimal256(precision, scale), _ => return Err(polars_err!(oos = OutOfSpecKind::NegativeFooterLength)), }; - (data_type, IpcField::default()) + (dtype, IpcField::default()) }, List(_) => deserialize_list(field)?, LargeList(_) => deserialize_large_list(field)?, @@ -379,32 +381,23 @@ pub(super) fn fb_to_schema( let fields = schema .fields()? 
.ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingFields))?; - let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| { - let (field, fields) = deserialize_field(field?)?; - Ok((field, fields)) - }))?; + + let mut arrow_schema = ArrowSchema::with_capacity(fields.len()); + let mut ipc_fields = Vec::with_capacity(fields.len()); + + for field in fields { + let (field, ipc_field) = deserialize_field(field?)?; + arrow_schema.insert(field.name.clone(), field); + ipc_fields.push(ipc_field); + } let is_little_endian = match schema.endianness()? { arrow_format::ipc::Endianness::Little => true, arrow_format::ipc::Endianness::Big => false, }; - let mut metadata = Metadata::default(); - if let Some(md_fields) = schema.custom_metadata()? { - for kv in md_fields { - let kv = kv?; - let k_str = kv.key()?; - let v_str = kv.value()?; - if let Some(k) = k_str { - if let Some(v) = v_str { - metadata.insert(k.to_string(), v.to_string()); - } - } - } - } - Ok(( - ArrowSchema { fields, metadata }, + arrow_schema, IpcSchema { fields: ipc_fields, is_little_endian, diff --git a/crates/polars-arrow/src/io/ipc/read/stream.rs b/crates/polars-arrow/src/io/ipc/read/stream.rs index a7acefe59090..87241596cdbe 100644 --- a/crates/polars-arrow/src/io/ipc/read/stream.rs +++ b/crates/polars-arrow/src/io/ipc/read/stream.rs @@ -1,8 +1,8 @@ use std::io::Read; -use ahash::AHashMap; use arrow_format::ipc::planus::ReadAsRoot; use polars_error::{polars_bail, polars_err, PolarsError, PolarsResult}; +use polars_utils::aliases::PlHashMap; use super::super::CONTINUATION_MARKER; use super::common::*; @@ -93,7 +93,7 @@ fn read_next( dictionaries: &mut Dictionaries, message_buffer: &mut Vec, data_buffer: &mut Vec, - projection: &Option<(Vec, AHashMap, ArrowSchema)>, + projection: &Option<(Vec, PlHashMap, ArrowSchema)>, scratch: &mut Vec, ) -> PolarsResult> { // determine metadata length @@ -167,7 +167,7 @@ fn read_next( let chunk = read_record_batch( batch, - &metadata.schema.fields, + &metadata.schema, &metadata.ipc_schema, projection.as_ref().map(|x| x.0.as_ref()), None, @@ -201,7 +201,7 @@ fn read_next( read_dictionary( batch, - &metadata.schema.fields, + &metadata.schema, &metadata.ipc_schema, dictionaries, &mut dict_reader, @@ -238,7 +238,7 @@ pub struct StreamReader { finished: bool, data_buffer: Vec, message_buffer: Vec, - projection: Option<(Vec, AHashMap, ArrowSchema)>, + projection: Option<(Vec, PlHashMap, ArrowSchema)>, scratch: Vec, } @@ -250,11 +250,7 @@ impl StreamReader { /// To check if the reader is done, use `is_finished(self)` pub fn new(reader: R, metadata: StreamMetadata, projection: Option>) -> Self { let projection = projection.map(|projection| { - let (p, h, fields) = prepare_projection(&metadata.schema.fields, projection); - let schema = ArrowSchema { - fields, - metadata: metadata.schema.metadata.clone(), - }; + let (p, h, schema) = prepare_projection(&metadata.schema, projection); (p, h, schema) }); diff --git a/crates/polars-arrow/src/io/ipc/read/stream_async.rs b/crates/polars-arrow/src/io/ipc/read/stream_async.rs index 8d66f81793ed..ab29550d8a14 100644 --- a/crates/polars-arrow/src/io/ipc/read/stream_async.rs +++ b/crates/polars-arrow/src/io/ipc/read/stream_async.rs @@ -132,9 +132,9 @@ async fn maybe_next( .read_to_end(&mut state.data_buffer) .await?; - read_record_batch( + let chunk = read_record_batch( batch, - &state.metadata.schema.fields, + &state.metadata.schema, &state.metadata.ipc_schema, None, None, @@ -144,8 +144,9 @@ async fn maybe_next( 0, state.data_buffer.len() as u64, &mut 
scratch, - ) - .map(|chunk| Some(StreamState::Some((state, chunk)))) + )?; + + Ok(Some(StreamState::Some((state, chunk)))) }, arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => { state.data_buffer.clear(); @@ -161,7 +162,7 @@ async fn maybe_next( read_dictionary( batch, - &state.metadata.schema.fields, + &state.metadata.schema, &state.metadata.ipc_schema, &mut state.dictionaries, &mut dict_reader, diff --git a/crates/polars-arrow/src/io/ipc/write/common.rs b/crates/polars-arrow/src/io/ipc/write/common.rs index 30312bf7f19d..2aebf1ec5d50 100644 --- a/crates/polars-arrow/src/io/ipc/write/common.rs +++ b/crates/polars-arrow/src/io/ipc/write/common.rs @@ -38,7 +38,7 @@ fn encode_dictionary( encoded_dictionaries: &mut Vec, ) -> PolarsResult<()> { use PhysicalType::*; - match array.data_type().to_physical_type() { + match array.dtype().to_physical_type() { Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null | FixedSizeBinary | BinaryView | Utf8View => Ok(()), Dictionary(key_type) => match_integer_type!(key_type, |$T| { @@ -231,7 +231,7 @@ fn serialize_compression( } fn set_variadic_buffer_counts(counts: &mut Vec, array: &dyn Array) { - match array.data_type() { + match array.dtype() { ArrowDataType::Utf8View => { let array = array.as_any().downcast_ref::().unwrap(); counts.push(array.data_buffers().len() as i64); @@ -297,7 +297,7 @@ fn chunk_to_bytes_amortized( let mut variadic_buffer_counts = vec![]; for array in chunk.arrays() { // We don't want to write all buffers in sliced arrays. - let array = match array.data_type() { + let array = match array.dtype() { ArrowDataType::BinaryView => { let concrete_arr = array.as_any().downcast_ref::().unwrap(); gc_bin_view(array, concrete_arr) @@ -432,7 +432,7 @@ impl DictionaryTracker { /// has never been seen before, return `Ok(true)` to indicate that the dictionary was just /// inserted. pub fn insert(&mut self, dict_id: i64, array: &dyn Array) -> PolarsResult { - let values = match array.data_type() { + let values = match array.dtype() { ArrowDataType::Dictionary(key_type, _, _) => { match_integer_type!(key_type, |$T| { let array = array diff --git a/crates/polars-arrow/src/io/ipc/write/file_async.rs b/crates/polars-arrow/src/io/ipc/write/file_async.rs index 142e18b71cbb..aaae101785bc 100644 --- a/crates/polars-arrow/src/io/ipc/write/file_async.rs +++ b/crates/polars-arrow/src/io/ipc/write/file_async.rs @@ -44,7 +44,7 @@ where ipc_fields: Option>, options: WriteOptions, ) -> Self { - let fields = ipc_fields.unwrap_or_else(|| default_ipc_fields(&schema.fields)); + let fields = ipc_fields.unwrap_or_else(|| default_ipc_fields(schema.iter_values())); let encoded = EncodedData { ipc_message: schema_to_bytes(&schema, &fields), arrow_data: vec![], diff --git a/crates/polars-arrow/src/io/ipc/write/mod.rs b/crates/polars-arrow/src/io/ipc/write/mod.rs index e272f6e1c2c5..d8afc1571721 100644 --- a/crates/polars-arrow/src/io/ipc/write/mod.rs +++ b/crates/polars-arrow/src/io/ipc/write/mod.rs @@ -27,28 +27,28 @@ pub mod file_async; use super::IpcField; use crate::datatypes::{ArrowDataType, Field}; -fn default_ipc_field(data_type: &ArrowDataType, current_id: &mut i64) -> IpcField { +fn default_ipc_field(dtype: &ArrowDataType, current_id: &mut i64) -> IpcField { use crate::datatypes::ArrowDataType::*; - match data_type.to_logical_type() { + match dtype.to_logical_type() { // single child => recurse Map(inner, ..) 
| FixedSizeList(inner, _) | LargeList(inner) | List(inner) => IpcField { - fields: vec![default_ipc_field(inner.data_type(), current_id)], + fields: vec![default_ipc_field(inner.dtype(), current_id)], dictionary_id: None, }, // multiple children => recurse Union(fields, ..) | Struct(fields) => IpcField { fields: fields .iter() - .map(|f| default_ipc_field(f.data_type(), current_id)) + .map(|f| default_ipc_field(f.dtype(), current_id)) .collect(), dictionary_id: None, }, // dictionary => current_id - Dictionary(_, data_type, _) => { + Dictionary(_, dtype, _) => { let dictionary_id = Some(*current_id); *current_id += 1; IpcField { - fields: vec![default_ipc_field(data_type, current_id)], + fields: vec![default_ipc_field(dtype, current_id)], dictionary_id, } }, @@ -61,10 +61,9 @@ fn default_ipc_field(data_type: &ArrowDataType, current_id: &mut i64) -> IpcFiel } /// Assigns every dictionary field a unique ID -pub fn default_ipc_fields(fields: &[Field]) -> Vec { +pub fn default_ipc_fields<'a>(fields: impl ExactSizeIterator) -> Vec { let mut dictionary_id = 0i64; fields - .iter() - .map(|field| default_ipc_field(field.data_type().to_logical_type(), &mut dictionary_id)) + .map(|field| default_ipc_field(field.dtype().to_logical_type(), &mut dictionary_id)) .collect() } diff --git a/crates/polars-arrow/src/io/ipc/write/schema.rs b/crates/polars-arrow/src/io/ipc/write/schema.rs index 8243e07a7d04..e8ed25c5c77e 100644 --- a/crates/polars-arrow/src/io/ipc/write/schema.rs +++ b/crates/polars-arrow/src/io/ipc/write/schema.rs @@ -32,19 +32,12 @@ pub fn serialize_schema( }; let fields = schema - .fields - .iter() + .iter_values() .zip(ipc_fields.iter()) .map(|(field, ipc_field)| serialize_field(field, ipc_field)) .collect::>(); - let custom_metadata = schema - .metadata - .iter() - .map(|(k, v)| key_value(k, v)) - .collect::>(); - - let custom_metadata = (!custom_metadata.is_empty()).then_some(custom_metadata); + let custom_metadata = None; arrow_format::ipc::Schema { endianness, @@ -63,50 +56,58 @@ fn key_value(key: impl Into, val: impl Into) -> arrow_format::ip fn write_metadata(metadata: &Metadata, kv_vec: &mut Vec) { for (k, v) in metadata { - if k != "ARROW:extension:name" && k != "ARROW:extension:metadata" { - kv_vec.push(key_value(k, v)); + if k.as_str() != "ARROW:extension:name" && k.as_str() != "ARROW:extension:metadata" { + kv_vec.push(key_value(k.clone().into_string(), v.clone().into_string())); } } } fn write_extension( name: &str, - metadata: &Option, + metadata: Option<&str>, kv_vec: &mut Vec, ) { if let Some(metadata) = metadata { - kv_vec.push(key_value("ARROW:extension:metadata", metadata)); + kv_vec.push(key_value("ARROW:extension:metadata".to_string(), metadata)); } - kv_vec.push(key_value("ARROW:extension:name", name)); + kv_vec.push(key_value("ARROW:extension:name".to_string(), name)); } /// Create an IPC Field from an Arrow Field pub(crate) fn serialize_field(field: &Field, ipc_field: &IpcField) -> arrow_format::ipc::Field { // custom metadata. 
let mut kv_vec = vec![]; - if let ArrowDataType::Extension(name, _, metadata) = field.data_type() { - write_extension(name, metadata, &mut kv_vec); + if let ArrowDataType::Extension(name, _, metadata) = field.dtype() { + write_extension( + name.as_str(), + metadata.as_ref().map(|x| x.as_str()), + &mut kv_vec, + ); } - let type_ = serialize_type(field.data_type()); - let children = serialize_children(field.data_type(), ipc_field); + let type_ = serialize_type(field.dtype()); + let children = serialize_children(field.dtype(), ipc_field); - let dictionary = - if let ArrowDataType::Dictionary(index_type, inner, is_ordered) = field.data_type() { - if let ArrowDataType::Extension(name, _, metadata) = inner.as_ref() { - write_extension(name, metadata, &mut kv_vec); - } - Some(serialize_dictionary( - index_type, - ipc_field - .dictionary_id - .expect("All Dictionary types have `dict_id`"), - *is_ordered, - )) - } else { - None - }; + let dictionary = if let ArrowDataType::Dictionary(index_type, inner, is_ordered) = field.dtype() + { + if let ArrowDataType::Extension(name, _, metadata) = inner.as_ref() { + write_extension( + name.as_str(), + metadata.as_ref().map(|x| x.as_str()), + &mut kv_vec, + ); + } + Some(serialize_dictionary( + index_type, + ipc_field + .dictionary_id + .expect("All Dictionary types have `dict_id`"), + *is_ordered, + )) + } else { + None + }; write_metadata(&field.metadata, &mut kv_vec); @@ -117,7 +118,7 @@ pub(crate) fn serialize_field(field: &Field, ipc_field: &IpcField) -> arrow_form }; arrow_format::ipc::Field { - name: Some(field.name.clone()), + name: Some(field.name.to_string()), nullable: field.is_nullable, type_: Some(type_), dictionary: dictionary.map(Box::new), @@ -135,10 +136,10 @@ fn serialize_time_unit(unit: &TimeUnit) -> arrow_format::ipc::TimeUnit { } } -fn serialize_type(data_type: &ArrowDataType) -> arrow_format::ipc::Type { +fn serialize_type(dtype: &ArrowDataType) -> arrow_format::ipc::Type { use arrow_format::ipc; use ArrowDataType::*; - match data_type { + match dtype { Null => ipc::Type::Null(Box::new(ipc::Null {})), Boolean => ipc::Type::Bool(Box::new(ipc::Bool {})), UInt8 => ipc::Type::Int(Box::new(ipc::Int { @@ -218,7 +219,7 @@ fn serialize_type(data_type: &ArrowDataType) -> arrow_format::ipc::Type { })), Timestamp(unit, tz) => ipc::Type::Timestamp(Box::new(ipc::Timestamp { unit: serialize_time_unit(unit), - timezone: tz.as_ref().cloned(), + timezone: tz.as_ref().map(|x| x.to_string()), })), Interval(unit) => ipc::Type::Interval(Box::new(ipc::Interval { unit: match unit { @@ -252,11 +253,11 @@ fn serialize_type(data_type: &ArrowDataType) -> arrow_format::ipc::Type { } fn serialize_children( - data_type: &ArrowDataType, + dtype: &ArrowDataType, ipc_field: &IpcField, ) -> Vec { use ArrowDataType::*; - match data_type { + match dtype { Null | Boolean | Int8 diff --git a/crates/polars-arrow/src/io/ipc/write/serialize/mod.rs b/crates/polars-arrow/src/io/ipc/write/serialize/mod.rs index b33f50b2277a..f13098477d4d 100644 --- a/crates/polars-arrow/src/io/ipc/write/serialize/mod.rs +++ b/crates/polars-arrow/src/io/ipc/write/serialize/mod.rs @@ -50,7 +50,7 @@ pub fn write( null_count: array.null_count() as i64, }); use PhysicalType::*; - match array.data_type().to_physical_type() { + match array.dtype().to_physical_type() { Null => (), Boolean => write_boolean( array.as_any().downcast_ref().unwrap(), diff --git a/crates/polars-arrow/src/io/ipc/write/stream.rs b/crates/polars-arrow/src/io/ipc/write/stream.rs index 5122fd848c7a..330b35d4ca4b 100644 --- 
a/crates/polars-arrow/src/io/ipc/write/stream.rs +++ b/crates/polars-arrow/src/io/ipc/write/stream.rs @@ -59,7 +59,7 @@ impl StreamWriter { self.ipc_fields = Some(if let Some(ipc_fields) = ipc_fields { ipc_fields } else { - default_ipc_fields(&schema.fields) + default_ipc_fields(schema.iter_values()) }); let encoded_message = EncodedData { diff --git a/crates/polars-arrow/src/io/ipc/write/stream_async.rs b/crates/polars-arrow/src/io/ipc/write/stream_async.rs index 9858739134c2..3718d6f82b29 100644 --- a/crates/polars-arrow/src/io/ipc/write/stream_async.rs +++ b/crates/polars-arrow/src/io/ipc/write/stream_async.rs @@ -36,7 +36,7 @@ where ipc_fields: Option>, write_options: WriteOptions, ) -> Self { - let fields = ipc_fields.unwrap_or_else(|| default_ipc_fields(&schema.fields)); + let fields = ipc_fields.unwrap_or_else(|| default_ipc_fields(schema.iter_values())); let task = Some(Self::start(writer, schema, &fields[..])); Self { writer: None, diff --git a/crates/polars-arrow/src/io/ipc/write/writer.rs b/crates/polars-arrow/src/io/ipc/write/writer.rs index 361a40bf5e06..ec010d4d0180 100644 --- a/crates/polars-arrow/src/io/ipc/write/writer.rs +++ b/crates/polars-arrow/src/io/ipc/write/writer.rs @@ -66,7 +66,7 @@ impl FileWriter { let ipc_fields = if let Some(ipc_fields) = ipc_fields { ipc_fields } else { - default_ipc_fields(&schema.fields) + default_ipc_fields(schema.iter_values()) }; Self { diff --git a/crates/polars-arrow/src/legacy/array/fixed_size_list.rs b/crates/polars-arrow/src/legacy/array/fixed_size_list.rs index 31bc5880c68a..99382b0b6407 100644 --- a/crates/polars-arrow/src/legacy/array/fixed_size_list.rs +++ b/crates/polars-arrow/src/legacy/array/fixed_size_list.rs @@ -51,12 +51,12 @@ impl AnonymousBuilder { } pub fn finish(self, inner_dtype: Option<&ArrowDataType>) -> PolarsResult { - let mut inner_dtype = inner_dtype.unwrap_or_else(|| self.arrays[0].data_type()); + let mut inner_dtype = inner_dtype.unwrap_or_else(|| self.arrays[0].dtype()); if is_nested_null(inner_dtype) { for arr in &self.arrays { - if !is_nested_null(arr.data_type()) { - inner_dtype = arr.data_type(); + if !is_nested_null(arr.dtype()) { + inner_dtype = arr.dtype(); break; } } @@ -67,9 +67,9 @@ impl AnonymousBuilder { .arrays .iter() .map(|arr| { - if matches!(arr.data_type(), ArrowDataType::Null) { + if matches!(arr.dtype(), ArrowDataType::Null) { new_null_array(inner_dtype.clone(), arr.len()) - } else if is_nested_null(arr.data_type()) { + } else if is_nested_null(arr.dtype()) { convert_inner_type(&**arr, inner_dtype) } else { arr.to_boxed() @@ -79,9 +79,9 @@ impl AnonymousBuilder { let values = concatenate_owned_unchecked(&arrays)?; - let data_type = FixedSizeListArray::default_datatype(inner_dtype.clone(), self.width); + let dtype = FixedSizeListArray::default_datatype(inner_dtype.clone(), self.width); Ok(FixedSizeListArray::new( - data_type, + dtype, values, self.validity.map(|validity| validity.into()), )) diff --git a/crates/polars-arrow/src/legacy/array/list.rs b/crates/polars-arrow/src/legacy/array/list.rs index ff02011663cb..3e3118f25248 100644 --- a/crates/polars-arrow/src/legacy/array/list.rs +++ b/crates/polars-arrow/src/legacy/array/list.rs @@ -118,7 +118,7 @@ impl<'a> AnonymousBuilder<'a> { }, } } else { - let inner_dtype = inner_dtype.unwrap_or_else(|| self.arrays[0].data_type()); + let inner_dtype = inner_dtype.unwrap_or_else(|| self.arrays[0].dtype()); // check if there is a dtype that is not `Null` // if we find it, we will convert the null arrays @@ -126,8 +126,8 @@ impl<'a> 
AnonymousBuilder<'a> { let mut non_null_dtype = None; if is_nested_null(inner_dtype) { for arr in &self.arrays { - if !is_nested_null(arr.data_type()) { - non_null_dtype = Some(arr.data_type()); + if !is_nested_null(arr.dtype()) { + non_null_dtype = Some(arr.dtype()); break; } } @@ -139,7 +139,7 @@ impl<'a> AnonymousBuilder<'a> { .arrays .iter() .map(|arr| { - if is_nested_null(arr.data_type()) { + if is_nested_null(arr.dtype()) { convert_inner_type(&**arr, dtype) } else { arr.to_boxed() diff --git a/crates/polars-arrow/src/legacy/array/mod.rs b/crates/polars-arrow/src/legacy/array/mod.rs index 18e5a386df0e..bbb876283470 100644 --- a/crates/polars-arrow/src/legacy/array/mod.rs +++ b/crates/polars-arrow/src/legacy/array/mod.rs @@ -50,7 +50,7 @@ pub trait ListFromIter { /// Will produce incorrect arrays if size hint is incorrect. unsafe fn from_iter_primitive_trusted_len( iter: I, - data_type: ArrowDataType, + dtype: ArrowDataType, ) -> ListArray where T: NativeType, @@ -70,9 +70,9 @@ pub trait ListFromIter { // SAFETY: // offsets are monotonically increasing ListArray::new( - ListArray::::default_datatype(data_type.clone()), + ListArray::::default_datatype(dtype.clone()), Offsets::new_unchecked(offsets).into(), - Box::new(values.to(data_type)), + Box::new(values.to(dtype)), Some(validity.into()), ) } @@ -185,14 +185,12 @@ pub trait ListFromIter { } impl ListFromIter for ListArray {} -fn is_nested_null(data_type: &ArrowDataType) -> bool { - match data_type { +fn is_nested_null(dtype: &ArrowDataType) -> bool { + match dtype { ArrowDataType::Null => true, - ArrowDataType::LargeList(field) => is_nested_null(field.data_type()), - ArrowDataType::FixedSizeList(field, _) => is_nested_null(field.data_type()), - ArrowDataType::Struct(fields) => { - fields.iter().all(|field| is_nested_null(field.data_type())) - }, + ArrowDataType::LargeList(field) => is_nested_null(field.dtype()), + ArrowDataType::FixedSizeList(field, _) => is_nested_null(field.dtype()), + ArrowDataType::Struct(fields) => fields.iter().all(|field| is_nested_null(field.dtype())), _ => false, } } @@ -203,8 +201,8 @@ pub fn convert_inner_type(array: &dyn Array, dtype: &ArrowDataType) -> Box { let array = array.as_any().downcast_ref::().unwrap(); let inner = array.values(); - let new_values = convert_inner_type(inner.as_ref(), field.data_type()); - let dtype = LargeListArray::default_datatype(new_values.data_type().clone()); + let new_values = convert_inner_type(inner.as_ref(), field.dtype()); + let dtype = LargeListArray::default_datatype(new_values.dtype().clone()); LargeListArray::new( dtype, array.offsets().clone(), @@ -216,9 +214,8 @@ pub fn convert_inner_type(array: &dyn Array, dtype: &ArrowDataType) -> Box { let array = array.as_any().downcast_ref::().unwrap(); let inner = array.values(); - let new_values = convert_inner_type(inner.as_ref(), field.data_type()); - let dtype = - FixedSizeListArray::default_datatype(new_values.data_type().clone(), *width); + let new_values = convert_inner_type(inner.as_ref(), field.dtype()); + let dtype = FixedSizeListArray::default_datatype(new_values.dtype().clone(), *width); FixedSizeListArray::new(dtype, new_values, array.validity().cloned()).boxed() }, ArrowDataType::Struct(fields) => { @@ -227,7 +224,7 @@ pub fn convert_inner_type(array: &dyn Array, dtype: &ArrowDataType) -> Box>(); StructArray::new(dtype.clone(), new_values, array.validity().cloned()).boxed() }, diff --git a/crates/polars-arrow/src/legacy/array/null.rs b/crates/polars-arrow/src/legacy/array/null.rs index 
6a802540db83..ec630250e0c5 100644 --- a/crates/polars-arrow/src/legacy/array/null.rs +++ b/crates/polars-arrow/src/legacy/array/null.rs @@ -10,7 +10,7 @@ pub struct MutableNullArray { } impl MutableArray for MutableNullArray { - fn data_type(&self) -> &ArrowDataType { + fn dtype(&self) -> &ArrowDataType { &ArrowDataType::Null } diff --git a/crates/polars-arrow/src/legacy/kernels/atan2.rs b/crates/polars-arrow/src/legacy/kernels/atan2.rs index 7884ab18d09f..40d3d527b24a 100644 --- a/crates/polars-arrow/src/legacy/kernels/atan2.rs +++ b/crates/polars-arrow/src/legacy/kernels/atan2.rs @@ -8,5 +8,5 @@ pub fn atan2(arr_1: &PrimitiveArray, arr_2: &PrimitiveArray) -> Primiti where T: Float + NativeType, { - binary(arr_1, arr_2, arr_1.data_type().clone(), |a, b| a.atan2(b)) + binary(arr_1, arr_2, arr_1.dtype().clone(), |a, b| a.atan2(b)) } diff --git a/crates/polars-arrow/src/legacy/kernels/list.rs b/crates/polars-arrow/src/legacy/kernels/list.rs index 2b21d872c948..4f3f332dac28 100644 --- a/crates/polars-arrow/src/legacy/kernels/list.rs +++ b/crates/polars-arrow/src/legacy/kernels/list.rs @@ -139,7 +139,7 @@ pub fn array_to_unit_list(array: ArrayRef) -> ListArray { // offsets are monotonically increasing unsafe { let offsets: OffsetsBuffer = Offsets::new_unchecked(offsets).into(); - let dtype = ListArray::::default_datatype(array.data_type().clone()); + let dtype = ListArray::::default_datatype(array.dtype().clone()); ListArray::::new(dtype, offsets, array, None) } } diff --git a/crates/polars-arrow/src/legacy/kernels/pow.rs b/crates/polars-arrow/src/legacy/kernels/pow.rs index 35ab21bcf7f8..a790e6193129 100644 --- a/crates/polars-arrow/src/legacy/kernels/pow.rs +++ b/crates/polars-arrow/src/legacy/kernels/pow.rs @@ -9,7 +9,5 @@ where T: Pow + NativeType, F: NativeType, { - binary(arr_1, arr_2, arr_1.data_type().clone(), |a, b| { - Pow::pow(a, b) - }) + binary(arr_1, arr_2, arr_1.dtype().clone(), |a, b| Pow::pow(a, b)) } diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/quantile_filter.rs b/crates/polars-arrow/src/legacy/kernels/rolling/quantile_filter.rs index 0588dec64ef9..40a464e6f5bc 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/quantile_filter.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/quantile_filter.rs @@ -6,7 +6,6 @@ use std::ops::{Add, Div, Mul, Sub}; use num_traits::NumCast; use polars_utils::index::{Bounded, Indexable, NullCount}; -use polars_utils::iter::IntoIteratorCopied; use polars_utils::nulls::IsNull; use polars_utils::slice::{GetSaferUnchecked, SliceAble}; use polars_utils::sort::arg_sort_ascending; @@ -86,11 +85,7 @@ where impl<'a, A> Block<'a, A> where - A: Indexable - + Bounded - + NullCount - + IntoIteratorCopied::Item> - + Clone, + A: Indexable + Bounded + NullCount + Clone, ::Item: TotalOrd + Copy + IsNull + Debug + 'a, { fn new( @@ -101,11 +96,7 @@ where ) -> Self { debug_assert!(!alpha.is_empty()); let k = alpha.len(); - let pi = arg_sort_ascending( - ::into_iter(alpha.clone()), - scratch, - alpha.len(), - ); + let pi = arg_sort_ascending((0..alpha.len()).map(|i| alpha.get(i)), scratch, alpha.len()); let nulls_in_window = alpha.null_count(); let m_index = k / 2; @@ -384,11 +375,7 @@ trait LenGet { impl<'a, A> LenGet for &mut Block<'a, A> where - A: Indexable - + Bounded - + NullCount - + IntoIteratorCopied::Item> - + Clone, + A: Indexable + Bounded + NullCount + Clone, ::Item: Copy + TotalOrd + Debug + 'a, { type Item = ::Item; @@ -418,11 +405,7 @@ where impl<'a, A> BlockUnion<'a, A> where - A: Indexable - + Bounded - + NullCount - + 
IntoIteratorCopied::Item> - + Clone, + A: Indexable + Bounded + NullCount + Clone, ::Item: TotalOrd + Copy + Debug, { fn new(block_left: &'a mut Block<'a, A>, block_right: &'a mut Block<'a, A>) -> Self { @@ -462,11 +445,7 @@ where impl<'a, A> LenGet for BlockUnion<'a, A> where - A: Indexable - + Bounded - + NullCount - + IntoIteratorCopied::Item> - + Clone, + A: Indexable + Bounded + NullCount + Clone, ::Item: TotalOrd + Copy + Debug, { type Item = ::Item; @@ -679,12 +658,7 @@ pub(super) fn rolling_quantile::Item>>( quantile: f64, ) -> Out where - A: Indexable - + SliceAble - + Bounded - + NullCount - + IntoIteratorCopied::Item> - + Clone, + A: Indexable + SliceAble + Bounded + NullCount + Clone, ::Item: Default + TotalOrd + Copy + FinishLinear + Debug, { let mut scratch_left = vec![]; diff --git a/crates/polars-arrow/src/legacy/kernels/set.rs b/crates/polars-arrow/src/legacy/kernels/set.rs index 41f3dbcf5c3d..338fc4b1a17c 100644 --- a/crates/polars-arrow/src/legacy/kernels/set.rs +++ b/crates/polars-arrow/src/legacy/kernels/set.rs @@ -34,7 +34,7 @@ where } }); - PrimitiveArray::new(array.data_type().clone(), av.into(), None) + PrimitiveArray::new(array.dtype().clone(), av.into(), None) } /// Set values in a primitive array based on a mask array. This is fast when large chunks of bits are set or unset. @@ -42,7 +42,7 @@ pub fn set_with_mask( array: &PrimitiveArray, mask: &BooleanArray, value: T, - data_type: ArrowDataType, + dtype: ArrowDataType, ) -> PrimitiveArray { let values = array.values(); @@ -61,7 +61,7 @@ pub fn set_with_mask( valid.bitor(mask_bitmap) }); - PrimitiveArray::new(data_type, buf.into(), validity) + PrimitiveArray::new(dtype, buf.into(), validity) } /// Efficiently sets value at the indices from the iterator to `set_value`. @@ -70,7 +70,7 @@ pub fn scatter_single_non_null( array: &PrimitiveArray, idx: I, set_value: T, - data_type: ArrowDataType, + dtype: ArrowDataType, ) -> PolarsResult> where T: NativeType, @@ -89,7 +89,7 @@ where })?; Ok(PrimitiveArray::new( - data_type, + dtype, buf.into(), array.validity().cloned(), )) diff --git a/crates/polars-arrow/src/legacy/kernels/take_agg/var.rs b/crates/polars-arrow/src/legacy/kernels/take_agg/var.rs index 8fd54d712e94..62e2ba1353f2 100644 --- a/crates/polars-arrow/src/legacy/kernels/take_agg/var.rs +++ b/crates/polars-arrow/src/legacy/kernels/take_agg/var.rs @@ -1,6 +1,7 @@ use super::*; -/// Numerical stable online variance aggregation +/// Numerical stable online variance aggregation. +/// /// See: /// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products". /// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577. 
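The doc comment above cites Welford (1962) for numerically stable online variance. As a rough, standalone Rust sketch of that recurrence (illustrative only, not the polars take_agg kernel itself):

// Welford's online update: keep a running mean and the sum of squared
// deviations from it (m2); the sample variance is m2 / (n - 1).
fn welford_sample_variance(samples: &[f64]) -> Option<f64> {
    if samples.len() < 2 {
        return None;
    }
    let mut mean = 0.0;
    let mut m2 = 0.0;
    for (i, &x) in samples.iter().enumerate() {
        let n = (i + 1) as f64;
        let delta = x - mean;
        mean += delta / n;
        m2 += delta * (x - mean);
    }
    Some(m2 / (samples.len() as f64 - 1.0))
}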
diff --git a/crates/polars-arrow/src/legacy/utils.rs b/crates/polars-arrow/src/legacy/utils.rs index af482171a1a9..2469e51957cb 100644 --- a/crates/polars-arrow/src/legacy/utils.rs +++ b/crates/polars-arrow/src/legacy/utils.rs @@ -1,5 +1,3 @@ -use std::borrow::Borrow; - use crate::array::PrimitiveArray; use crate::bitmap::utils::set_bit_unchecked; use crate::bitmap::MutableBitmap; @@ -34,40 +32,6 @@ pub trait CustomIterTools: Iterator { { FromIteratorReversed::from_trusted_len_iter_rev(self) } - - fn all_equal(&mut self) -> bool - where - Self: Sized, - Self::Item: PartialEq, - { - match self.next() { - None => true, - Some(a) => self.all(|x| a == x), - } - } - - fn fold_options(&mut self, mut start: B, mut f: F) -> Option - where - Self: Iterator>, - F: FnMut(B, A) -> B, - { - for elt in self { - match elt { - Some(v) => start = f(start, v), - None => return None, - } - } - Some(start) - } - - fn contains(&mut self, query: &Q) -> bool - where - Self: Sized, - Self::Item: Borrow, - Q: PartialEq, - { - self.any(|x| x.borrow() == query) - } } pub trait CustomIterToolsSized: Iterator + Sized {} diff --git a/crates/polars-arrow/src/lib.rs b/crates/polars-arrow/src/lib.rs index 7de0d037ece1..15af97483a41 100644 --- a/crates/polars-arrow/src/lib.rs +++ b/crates/polars-arrow/src/lib.rs @@ -42,5 +42,4 @@ pub mod util; // re-exported because we return `Either` in our public API // re-exported to construct dictionaries -pub use ahash::AHashMap; pub use either::Either; diff --git a/crates/polars-arrow/src/mmap/array.rs b/crates/polars-arrow/src/mmap/array.rs index d22705c63b63..8822824858c6 100644 --- a/crates/polars-arrow/src/mmap/array.rs +++ b/crates/polars-arrow/src/mmap/array.rs @@ -187,9 +187,9 @@ fn mmap_fixed_size_binary>( node: &Node, block_offset: usize, buffers: &mut VecDeque, - data_type: &ArrowDataType, + dtype: &ArrowDataType, ) -> PolarsResult { - let bytes_per_row = if let ArrowDataType::FixedSizeBinary(bytes_per_row) = data_type { + let bytes_per_row = if let ArrowDataType::FixedSizeBinary(bytes_per_row) = dtype { bytes_per_row } else { polars_bail!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::InvalidDataType); @@ -337,14 +337,14 @@ fn mmap_list>( data: Arc, node: &Node, block_offset: usize, - data_type: &ArrowDataType, + dtype: &ArrowDataType, ipc_field: &IpcField, dictionaries: &Dictionaries, field_nodes: &mut VecDeque, variadic_buffer_counts: &mut VecDeque, buffers: &mut VecDeque, ) -> PolarsResult { - let child = ListArray::::try_get_child(data_type)?.data_type(); + let child = ListArray::::try_get_child(dtype)?.dtype(); let (num_rows, null_count) = get_num_rows_and_null_count(node)?; let data_ref = data.as_ref().as_ref(); @@ -383,16 +383,14 @@ fn mmap_fixed_size_list>( data: Arc, node: &Node, block_offset: usize, - data_type: &ArrowDataType, + dtype: &ArrowDataType, ipc_field: &IpcField, dictionaries: &Dictionaries, field_nodes: &mut VecDeque, variadic_buffer_counts: &mut VecDeque, buffers: &mut VecDeque, ) -> PolarsResult { - let child = FixedSizeListArray::try_child_and_size(data_type)? 
- .0 - .data_type(); + let child = FixedSizeListArray::try_child_and_size(dtype)?.0.dtype(); let (num_rows, null_count) = get_num_rows_and_null_count(node)?; let data_ref = data.as_ref().as_ref(); @@ -428,14 +426,14 @@ fn mmap_struct>( data: Arc, node: &Node, block_offset: usize, - data_type: &ArrowDataType, + dtype: &ArrowDataType, ipc_field: &IpcField, dictionaries: &Dictionaries, field_nodes: &mut VecDeque, variadic_buffer_counts: &mut VecDeque, buffers: &mut VecDeque, ) -> PolarsResult { - let children = StructArray::try_get_fields(data_type)?; + let children = StructArray::try_get_fields(dtype)?; let (num_rows, null_count) = get_num_rows_and_null_count(node)?; let data_ref = data.as_ref().as_ref(); @@ -444,7 +442,7 @@ fn mmap_struct>( let values = children .iter() - .map(|f| &f.data_type) + .map(|f| &f.dtype) .zip(ipc_field.fields.iter()) .map(|(child, ipc)| { get_array( @@ -514,7 +512,7 @@ fn mmap_dict>( fn get_array>( data: Arc, block_offset: usize, - data_type: &ArrowDataType, + dtype: &ArrowDataType, ipc_field: &IpcField, dictionaries: &Dictionaries, field_nodes: &mut VecDeque, @@ -526,7 +524,7 @@ fn get_array>( || polars_err!(ComputeError: "out-of-spec: {:?}", OutOfSpecKind::ExpectedBuffer), )?; - match data_type.to_physical_type() { + match dtype.to_physical_type() { Null => mmap_null(data, &node, block_offset, buffers), Boolean => mmap_boolean(data, &node, block_offset, buffers), Primitive(p) => with_match_primitive_type_full!(p, |$T| { @@ -536,13 +534,13 @@ fn get_array>( Utf8View | BinaryView => { mmap_binview(data, &node, block_offset, buffers, variadic_buffer_counts) }, - FixedSizeBinary => mmap_fixed_size_binary(data, &node, block_offset, buffers, data_type), + FixedSizeBinary => mmap_fixed_size_binary(data, &node, block_offset, buffers, dtype), LargeBinary | LargeUtf8 => mmap_binary::(data, &node, block_offset, buffers), List => mmap_list::( data, &node, block_offset, - data_type, + dtype, ipc_field, dictionaries, field_nodes, @@ -553,7 +551,7 @@ fn get_array>( data, &node, block_offset, - data_type, + dtype, ipc_field, dictionaries, field_nodes, @@ -564,7 +562,7 @@ fn get_array>( data, &node, block_offset, - data_type, + dtype, ipc_field, dictionaries, field_nodes, @@ -575,7 +573,7 @@ fn get_array>( data, &node, block_offset, - data_type, + dtype, ipc_field, dictionaries, field_nodes, @@ -587,7 +585,7 @@ fn get_array>( data, &node, block_offset, - data_type, + dtype, ipc_field, dictionaries, field_nodes, @@ -603,7 +601,7 @@ fn get_array>( pub(crate) unsafe fn mmap>( data: Arc, block_offset: usize, - data_type: ArrowDataType, + dtype: ArrowDataType, ipc_field: &IpcField, dictionaries: &Dictionaries, field_nodes: &mut VecDeque, @@ -613,7 +611,7 @@ pub(crate) unsafe fn mmap>( let array = get_array( data, block_offset, - &data_type, + &dtype, ipc_field, dictionaries, field_nodes, @@ -622,5 +620,5 @@ pub(crate) unsafe fn mmap>( )?; // The unsafety comes from the fact that `array` is not necessarily valid - // the IPC file may be corrupted (e.g. 
invalid offsets or non-utf8 data) - unsafe { try_from(InternalArrowArray::new(array, data_type)) } + unsafe { try_from(InternalArrowArray::new(array, dtype)) } } diff --git a/crates/polars-arrow/src/mmap/mod.rs b/crates/polars-arrow/src/mmap/mod.rs index bda45655d08c..b934c31de563 100644 --- a/crates/polars-arrow/src/mmap/mod.rs +++ b/crates/polars-arrow/src/mmap/mod.rs @@ -7,9 +7,10 @@ mod array; use arrow_format::ipc::planus::ReadAsRoot; use arrow_format::ipc::{Block, MessageRef, RecordBatchRef}; use polars_error::{polars_bail, polars_err, to_compute_err, PolarsResult}; +use polars_utils::pl_str::PlSmallStr; use crate::array::Array; -use crate::datatypes::{ArrowDataType, Field}; +use crate::datatypes::{ArrowDataType, ArrowSchema, Field}; use crate::io::ipc::read::file::{get_dictionary_batch, get_record_batch}; use crate::io::ipc::read::{ first_dict_field, Dictionaries, FileMetadata, IpcBuffer, Node, OutOfSpecKind, @@ -71,7 +72,7 @@ fn get_buffers_nodes(batch: RecordBatchRef) -> PolarsResult<(VecDeque } unsafe fn _mmap_record>( - fields: &[Field], + fields: &ArrowSchema, ipc_fields: &[IpcField], data: Arc, batch: RecordBatchRef, @@ -86,15 +87,15 @@ unsafe fn _mmap_record>( .unwrap_or_else(VecDeque::new); fields - .iter() - .map(|f| &f.data_type) + .iter_values() + .map(|f| &f.dtype) .cloned() .zip(ipc_fields) - .map(|(data_type, ipc_field)| { + .map(|(dtype, ipc_field)| { array::mmap( data.clone(), offset, - data_type, + dtype, ipc_field, dictionaries, &mut field_nodes, @@ -107,7 +108,7 @@ unsafe fn _mmap_record>( } unsafe fn _mmap_unchecked>( - fields: &[Field], + fields: &ArrowSchema, ipc_fields: &[IpcField], data: Arc, block: Block, @@ -147,7 +148,7 @@ pub unsafe fn mmap_unchecked>( let (message, offset) = read_message(data.as_ref().as_ref(), block)?; let batch = get_record_batch(message)?; _mmap_record( - &metadata.schema.fields, + &metadata.schema, &metadata.ipc_schema.fields, data.clone(), batch, @@ -169,7 +170,7 @@ unsafe fn mmap_dictionary>( .id() .map_err(|err| polars_err!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::InvalidFlatbufferId(err)))?; let (first_field, first_ipc_field) = - first_dict_field(id, &metadata.schema.fields, &metadata.ipc_schema.fields)?; + first_dict_field(id, &metadata.schema, &metadata.ipc_schema.fields)?; let batch = batch .data() @@ -177,7 +178,7 @@ unsafe fn mmap_dictionary>( .ok_or_else(|| polars_err!(ComputeError: "out-of-spec {:?}", OutOfSpecKind::MissingData))?; let value_type = if let ArrowDataType::Dictionary(_, value_type, _) = - first_field.data_type.to_logical_type() + first_field.dtype.to_logical_type() { value_type.as_ref() } else { @@ -185,10 +186,10 @@ unsafe fn mmap_dictionary>( }; // Make a fake schema for the dictionary batch. - let field = Field::new("", value_type.clone(), false); + let field = Field::new(PlSmallStr::EMPTY, value_type.clone(), false); let chunk = _mmap_record( - &[field], + &std::iter::once((field.name.clone(), field)).collect(), &[first_ipc_field.clone()], data.clone(), batch, diff --git a/crates/polars-arrow/src/offset.rs b/crates/polars-arrow/src/offset.rs index 33b3058cbb78..ae4583dfe6f4 100644 --- a/crates/polars-arrow/src/offset.rs +++ b/crates/polars-arrow/src/offset.rs @@ -518,6 +518,14 @@ impl OffsetsBuffer { pub fn into_inner(self) -> Buffer { self.0 } + + /// Returns the offset difference between `start` and `end`. 
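The `delta` body added just below reads two entries of the offsets buffer; for plain `i64` offsets it reduces to roughly the following (a sketch over a bare slice, not the real `OffsetsBuffer` type):

fn offsets_delta(offsets: &[i64], start: usize, end: usize) -> usize {
    // offsets[i] is where slot i begins and offsets[i + 1] is where it ends,
    // so this is the total length covered by slots start..=end.
    assert!(start <= end);
    (offsets[end + 1] - offsets[start]) as usize
}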
+ #[inline] + pub fn delta(&self, start: usize, end: usize) -> usize { + assert!(start <= end); + + (self.0[end + 1] - self.0[start]).to_usize() + } } impl From<&OffsetsBuffer> for OffsetsBuffer { diff --git a/crates/polars-arrow/src/pushable.rs b/crates/polars-arrow/src/pushable.rs index de642833bc6a..7b8857ab3a15 100644 --- a/crates/polars-arrow/src/pushable.rs +++ b/crates/polars-arrow/src/pushable.rs @@ -209,22 +209,7 @@ where } fn extend_constant(&mut self, additional: usize, value: T) { - let value = value.as_ref(); - // First push a value to get the View - MutableBinaryViewArray::push_value(self, value); - - // And then use that new view to extend - let views = self.views_mut(); - let view = *views.last().unwrap(); - - let remaining = additional - 1; - for _ in 0..remaining { - views.push(view); - } - - if let Some(bitmap) = self.validity() { - bitmap.extend_constant(remaining, true) - } + MutableBinaryViewArray::extend_constant(self, additional, Some(value)); } #[inline] @@ -262,22 +247,7 @@ where } fn extend_constant(&mut self, additional: usize, value: Option) { - let value = value.as_ref(); - // First push a value to get the View - MutableBinaryViewArray::push(self, value); - - // And then use that new view to extend - let views = self.views_mut(); - let view = *views.last().unwrap(); - - let remaining = additional - 1; - for _ in 0..remaining { - views.push(view); - } - - if let Some(bitmap) = self.validity() { - bitmap.extend_constant(remaining, true) - } + MutableBinaryViewArray::extend_constant(self, additional, value); } #[inline] diff --git a/crates/polars-arrow/src/scalar/README.md b/crates/polars-arrow/src/scalar/README.md index b780081b6131..ea6c3791d6be 100644 --- a/crates/polars-arrow/src/scalar/README.md +++ b/crates/polars-arrow/src/scalar/README.md @@ -19,7 +19,7 @@ Specifically, a `Scalar` is a trait object that can be downcasted to concrete im Like `Array`, `Scalar` implements -- `data_type`, which is used to perform the correct downcast +- `dtype`, which is used to perform the correct downcast - `is_valid`, to tell whether the scalar is null or not ### There is one implementation per arrows' physical type diff --git a/crates/polars-arrow/src/scalar/binary.rs b/crates/polars-arrow/src/scalar/binary.rs index bdc1f8b8243a..f758cf021b1c 100644 --- a/crates/polars-arrow/src/scalar/binary.rs +++ b/crates/polars-arrow/src/scalar/binary.rs @@ -45,7 +45,7 @@ impl Scalar for BinaryScalar { } #[inline] - fn data_type(&self) -> &ArrowDataType { + fn dtype(&self) -> &ArrowDataType { if O::IS_LARGE { &ArrowDataType::LargeBinary } else { diff --git a/crates/polars-arrow/src/scalar/binview.rs b/crates/polars-arrow/src/scalar/binview.rs index e96c90c04adb..958037041623 100644 --- a/crates/polars-arrow/src/scalar/binview.rs +++ b/crates/polars-arrow/src/scalar/binview.rs @@ -62,7 +62,7 @@ impl Scalar for BinaryViewScalar { } #[inline] - fn data_type(&self) -> &ArrowDataType { + fn dtype(&self) -> &ArrowDataType { if T::IS_UTF8 { &ArrowDataType::Utf8View } else { diff --git a/crates/polars-arrow/src/scalar/boolean.rs b/crates/polars-arrow/src/scalar/boolean.rs index 82d1e9c6e7ed..44158d8c3636 100644 --- a/crates/polars-arrow/src/scalar/boolean.rs +++ b/crates/polars-arrow/src/scalar/boolean.rs @@ -33,7 +33,7 @@ impl Scalar for BooleanScalar { } #[inline] - fn data_type(&self) -> &ArrowDataType { + fn dtype(&self) -> &ArrowDataType { &ArrowDataType::Boolean } } diff --git a/crates/polars-arrow/src/scalar/dictionary.rs b/crates/polars-arrow/src/scalar/dictionary.rs index 
f9559009a1c6..b92a99355559 100644 --- a/crates/polars-arrow/src/scalar/dictionary.rs +++ b/crates/polars-arrow/src/scalar/dictionary.rs @@ -9,12 +9,12 @@ use crate::datatypes::ArrowDataType; pub struct DictionaryScalar { value: Option>, phantom: std::marker::PhantomData, - data_type: ArrowDataType, + dtype: ArrowDataType, } impl PartialEq for DictionaryScalar { fn eq(&self, other: &Self) -> bool { - (self.data_type == other.data_type) && (self.value.as_ref() == other.value.as_ref()) + (self.dtype == other.dtype) && (self.value.as_ref() == other.value.as_ref()) } } @@ -22,14 +22,14 @@ impl DictionaryScalar { /// returns a new [`DictionaryScalar`] /// # Panics /// iff - /// * the `data_type` is not `List` or `LargeList` (depending on this scalar's offset `O`) - /// * the child of the `data_type` is not equal to the `values` + /// * the `dtype` is not `List` or `LargeList` (depending on this scalar's offset `O`) + /// * the child of the `dtype` is not equal to the `values` #[inline] - pub fn new(data_type: ArrowDataType, value: Option>) -> Self { + pub fn new(dtype: ArrowDataType, value: Option>) -> Self { Self { value, phantom: std::marker::PhantomData, - data_type, + dtype, } } @@ -48,7 +48,7 @@ impl Scalar for DictionaryScalar { self.value.is_some() } - fn data_type(&self) -> &ArrowDataType { - &self.data_type + fn dtype(&self) -> &ArrowDataType { + &self.dtype } } diff --git a/crates/polars-arrow/src/scalar/equal.rs b/crates/polars-arrow/src/scalar/equal.rs index 78055671b32e..3978765fe73a 100644 --- a/crates/polars-arrow/src/scalar/equal.rs +++ b/crates/polars-arrow/src/scalar/equal.rs @@ -30,12 +30,12 @@ macro_rules! dyn_eq { } fn equal(lhs: &dyn Scalar, rhs: &dyn Scalar) -> bool { - if lhs.data_type() != rhs.data_type() { + if lhs.dtype() != rhs.dtype() { return false; } use PhysicalType::*; - match lhs.data_type().to_physical_type() { + match lhs.dtype().to_physical_type() { Null => dyn_eq!(NullScalar, lhs, rhs), Boolean => dyn_eq!(BooleanScalar, lhs, rhs), Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { diff --git a/crates/polars-arrow/src/scalar/fixed_size_binary.rs b/crates/polars-arrow/src/scalar/fixed_size_binary.rs index 0c6c5602c2bc..a14d2886d75d 100644 --- a/crates/polars-arrow/src/scalar/fixed_size_binary.rs +++ b/crates/polars-arrow/src/scalar/fixed_size_binary.rs @@ -5,31 +5,31 @@ use crate::datatypes::ArrowDataType; /// The [`Scalar`] implementation of fixed size binary ([`Option>`]). pub struct FixedSizeBinaryScalar { value: Option>, - data_type: ArrowDataType, + dtype: ArrowDataType, } impl FixedSizeBinaryScalar { /// Returns a new [`FixedSizeBinaryScalar`]. /// # Panics /// iff - /// * the `data_type` is not `FixedSizeBinary` + /// * the `dtype` is not `FixedSizeBinary` /// * the size of child binary is not equal #[inline] - pub fn new>>(data_type: ArrowDataType, value: Option
) -> Self { + pub fn new>>(dtype: ArrowDataType, value: Option
) -> Self { assert_eq!( - data_type.to_physical_type(), + dtype.to_physical_type(), crate::datatypes::PhysicalType::FixedSizeBinary ); Self { value: value.map(|x| { let x: Vec = x.into(); assert_eq!( - data_type.to_logical_type(), + dtype.to_logical_type(), &ArrowDataType::FixedSizeBinary(x.len()) ); x.into_boxed_slice() }), - data_type, + dtype, } } @@ -52,7 +52,7 @@ impl Scalar for FixedSizeBinaryScalar { } #[inline] - fn data_type(&self) -> &ArrowDataType { - &self.data_type + fn dtype(&self) -> &ArrowDataType { + &self.dtype } } diff --git a/crates/polars-arrow/src/scalar/fixed_size_list.rs b/crates/polars-arrow/src/scalar/fixed_size_list.rs index 0ef0f083943c..5810eeab2dfc 100644 --- a/crates/polars-arrow/src/scalar/fixed_size_list.rs +++ b/crates/polars-arrow/src/scalar/fixed_size_list.rs @@ -9,12 +9,12 @@ use crate::datatypes::ArrowDataType; #[derive(Debug, Clone)] pub struct FixedSizeListScalar { values: Option>, - data_type: ArrowDataType, + dtype: ArrowDataType, } impl PartialEq for FixedSizeListScalar { fn eq(&self, other: &Self) -> bool { - (self.data_type == other.data_type) + (self.dtype == other.dtype) && (self.values.is_some() == other.values.is_some()) && ((self.values.is_none()) | (self.values.as_ref() == other.values.as_ref())) } @@ -24,18 +24,18 @@ impl FixedSizeListScalar { /// returns a new [`FixedSizeListScalar`] /// # Panics /// iff - /// * the `data_type` is not `FixedSizeList` - /// * the child of the `data_type` is not equal to the `values` + /// * the `dtype` is not `FixedSizeList` + /// * the child of the `dtype` is not equal to the `values` /// * the size of child array is not equal #[inline] - pub fn new(data_type: ArrowDataType, values: Option>) -> Self { - let (field, size) = FixedSizeListArray::get_child_and_size(&data_type); - let inner_data_type = field.data_type(); + pub fn new(dtype: ArrowDataType, values: Option>) -> Self { + let (field, size) = FixedSizeListArray::get_child_and_size(&dtype); + let inner_dtype = field.dtype(); let values = values.inspect(|x| { - assert_eq!(inner_data_type, x.data_type()); + assert_eq!(inner_dtype, x.dtype()); assert_eq!(size, x.len()); }); - Self { values, data_type } + Self { values, dtype } } /// The values of the [`FixedSizeListScalar`] @@ -53,7 +53,7 @@ impl Scalar for FixedSizeListScalar { self.values.is_some() } - fn data_type(&self) -> &ArrowDataType { - &self.data_type + fn dtype(&self) -> &ArrowDataType { + &self.dtype } } diff --git a/crates/polars-arrow/src/scalar/list.rs b/crates/polars-arrow/src/scalar/list.rs index c58a11150e08..6978c6e61860 100644 --- a/crates/polars-arrow/src/scalar/list.rs +++ b/crates/polars-arrow/src/scalar/list.rs @@ -12,12 +12,12 @@ pub struct ListScalar { values: Box, is_valid: bool, phantom: std::marker::PhantomData, - data_type: ArrowDataType, + dtype: ArrowDataType, } impl PartialEq for ListScalar { fn eq(&self, other: &Self) -> bool { - (self.data_type == other.data_type) + (self.dtype == other.dtype) && (self.is_valid == other.is_valid) && ((!self.is_valid) | (self.values.as_ref() == other.values.as_ref())) } @@ -27,23 +27,23 @@ impl ListScalar { /// returns a new [`ListScalar`] /// # Panics /// iff - /// * the `data_type` is not `List` or `LargeList` (depending on this scalar's offset `O`) - /// * the child of the `data_type` is not equal to the `values` + /// * the `dtype` is not `List` or `LargeList` (depending on this scalar's offset `O`) + /// * the child of the `dtype` is not equal to the `values` #[inline] - pub fn new(data_type: ArrowDataType, values: Option>) -> 
Self { - let inner_data_type = ListArray::::get_child_type(&data_type); + pub fn new(dtype: ArrowDataType, values: Option>) -> Self { + let inner_dtype = ListArray::::get_child_type(&dtype); let (is_valid, values) = match values { Some(values) => { - assert_eq!(inner_data_type, values.data_type()); + assert_eq!(inner_dtype, values.dtype()); (true, values) }, - None => (false, new_empty_array(inner_data_type.clone())), + None => (false, new_empty_array(inner_dtype.clone())), }; Self { values, is_valid, phantom: std::marker::PhantomData, - data_type, + dtype, } } @@ -62,7 +62,7 @@ impl Scalar for ListScalar { self.is_valid } - fn data_type(&self) -> &ArrowDataType { - &self.data_type + fn dtype(&self) -> &ArrowDataType { + &self.dtype } } diff --git a/crates/polars-arrow/src/scalar/map.rs b/crates/polars-arrow/src/scalar/map.rs index 6dd204a83e02..f9e7b238c481 100644 --- a/crates/polars-arrow/src/scalar/map.rs +++ b/crates/polars-arrow/src/scalar/map.rs @@ -10,12 +10,12 @@ use crate::datatypes::ArrowDataType; pub struct MapScalar { values: Box, is_valid: bool, - data_type: ArrowDataType, + dtype: ArrowDataType, } impl PartialEq for MapScalar { fn eq(&self, other: &Self) -> bool { - (self.data_type == other.data_type) + (self.dtype == other.dtype) && (self.is_valid == other.is_valid) && ((!self.is_valid) | (self.values.as_ref() == other.values.as_ref())) } @@ -25,23 +25,23 @@ impl MapScalar { /// returns a new [`MapScalar`] /// # Panics /// iff - /// * the `data_type` is not `Map` - /// * the child of the `data_type` is not equal to the `values` + /// * the `dtype` is not `Map` + /// * the child of the `dtype` is not equal to the `values` #[inline] - pub fn new(data_type: ArrowDataType, values: Option>) -> Self { - let inner_field = MapArray::try_get_field(&data_type).unwrap(); - let inner_data_type = inner_field.data_type(); + pub fn new(dtype: ArrowDataType, values: Option>) -> Self { + let inner_field = MapArray::try_get_field(&dtype).unwrap(); + let inner_dtype = inner_field.dtype(); let (is_valid, values) = match values { Some(values) => { - assert_eq!(inner_data_type, values.data_type()); + assert_eq!(inner_dtype, values.dtype()); (true, values) }, - None => (false, new_empty_array(inner_data_type.clone())), + None => (false, new_empty_array(inner_dtype.clone())), }; Self { values, is_valid, - data_type, + dtype, } } @@ -60,7 +60,7 @@ impl Scalar for MapScalar { self.is_valid } - fn data_type(&self) -> &ArrowDataType { - &self.data_type + fn dtype(&self) -> &ArrowDataType { + &self.dtype } } diff --git a/crates/polars-arrow/src/scalar/mod.rs b/crates/polars-arrow/src/scalar/mod.rs index 54bd0705bf54..adcf006862db 100644 --- a/crates/polars-arrow/src/scalar/mod.rs +++ b/crates/polars-arrow/src/scalar/mod.rs @@ -46,7 +46,7 @@ pub trait Scalar: std::fmt::Debug + Send + Sync + dyn_clone::DynClone + 'static fn is_valid(&self) -> bool; /// the logical type. - fn data_type(&self) -> &ArrowDataType; + fn dtype(&self) -> &ArrowDataType; } dyn_clone::clone_trait_object!(Scalar); @@ -101,14 +101,14 @@ macro_rules! dyn_new_list { } else { None }; - Box::new(ListScalar::<$type>::new(array.data_type().clone(), value)) + Box::new(ListScalar::<$type>::new(array.dtype().clone(), value)) }}; } /// creates a new [`Scalar`] from an [`Array`]. 
pub fn new_scalar(array: &dyn Array, index: usize) -> Box { use PhysicalType::*; - match array.data_type().to_physical_type() { + match array.dtype().to_physical_type() { Null => Box::new(NullScalar::new()), Boolean => { let array = array.as_any().downcast_ref::().unwrap(); @@ -129,7 +129,7 @@ pub fn new_scalar(array: &dyn Array, index: usize) -> Box { } else { None }; - Box::new(PrimitiveScalar::new(array.data_type().clone(), value)) + Box::new(PrimitiveScalar::new(array.dtype().clone(), value)) }), BinaryView => dyn_new_binview!(array, index, [u8]), Utf8View => dyn_new_binview!(array, index, str), @@ -147,9 +147,9 @@ pub fn new_scalar(array: &dyn Array, index: usize) -> Box { .iter() .map(|x| new_scalar(x.as_ref(), index)) .collect(); - Box::new(StructScalar::new(array.data_type().clone(), Some(values))) + Box::new(StructScalar::new(array.dtype().clone(), Some(values))) } else { - Box::new(StructScalar::new(array.data_type().clone(), None)) + Box::new(StructScalar::new(array.dtype().clone(), None)) } }, FixedSizeBinary => { @@ -162,7 +162,7 @@ pub fn new_scalar(array: &dyn Array, index: usize) -> Box { } else { None }; - Box::new(FixedSizeBinaryScalar::new(array.data_type().clone(), value)) + Box::new(FixedSizeBinaryScalar::new(array.dtype().clone(), value)) }, FixedSizeList => { let array = array.as_any().downcast_ref::().unwrap(); @@ -171,12 +171,12 @@ pub fn new_scalar(array: &dyn Array, index: usize) -> Box { } else { None }; - Box::new(FixedSizeListScalar::new(array.data_type().clone(), value)) + Box::new(FixedSizeListScalar::new(array.dtype().clone(), value)) }, Union => { let array = array.as_any().downcast_ref::().unwrap(); Box::new(UnionScalar::new( - array.data_type().clone(), + array.dtype().clone(), array.types()[index], array.value(index), )) @@ -188,7 +188,7 @@ pub fn new_scalar(array: &dyn Array, index: usize) -> Box { } else { None }; - Box::new(MapScalar::new(array.data_type().clone(), value)) + Box::new(MapScalar::new(array.dtype().clone(), value)) }, Dictionary(key_type) => match_integer_type!(key_type, |$T| { let array = array @@ -201,7 +201,7 @@ pub fn new_scalar(array: &dyn Array, index: usize) -> Box { None }; Box::new(DictionaryScalar::<$T>::new( - array.data_type().clone(), + array.dtype().clone(), value, )) }), diff --git a/crates/polars-arrow/src/scalar/null.rs b/crates/polars-arrow/src/scalar/null.rs index 3559d6cc8290..2071f0d4584e 100644 --- a/crates/polars-arrow/src/scalar/null.rs +++ b/crates/polars-arrow/src/scalar/null.rs @@ -31,7 +31,7 @@ impl Scalar for NullScalar { } #[inline] - fn data_type(&self) -> &ArrowDataType { + fn dtype(&self) -> &ArrowDataType { &ArrowDataType::Null } } diff --git a/crates/polars-arrow/src/scalar/primitive.rs b/crates/polars-arrow/src/scalar/primitive.rs index b25b09b3ec91..35214b270032 100644 --- a/crates/polars-arrow/src/scalar/primitive.rs +++ b/crates/polars-arrow/src/scalar/primitive.rs @@ -7,21 +7,21 @@ use crate::types::NativeType; #[derive(Debug, Clone, PartialEq, Eq)] pub struct PrimitiveScalar { value: Option, - data_type: ArrowDataType, + dtype: ArrowDataType, } impl PrimitiveScalar { /// Returns a new [`PrimitiveScalar`]. 
#[inline] - pub fn new(data_type: ArrowDataType, value: Option) -> Self { - if !data_type.to_physical_type().eq_primitive(T::PRIMITIVE) { + pub fn new(dtype: ArrowDataType, value: Option) -> Self { + if !dtype.to_physical_type().eq_primitive(T::PRIMITIVE) { panic!( "Type {} does not support logical type {:?}", std::any::type_name::(), - data_type + dtype ) } - Self { value, data_type } + Self { value, dtype } } /// Returns the optional value. @@ -32,9 +32,9 @@ impl PrimitiveScalar { /// Returns a new `PrimitiveScalar` with the same value but different [`ArrowDataType`] /// # Panic - /// This function panics if the `data_type` is not valid for self's physical type `T`. - pub fn to(self, data_type: ArrowDataType) -> Self { - Self::new(data_type, self.value) + /// This function panics if the `dtype` is not valid for self's physical type `T`. + pub fn to(self, dtype: ArrowDataType) -> Self { + Self::new(dtype, self.value) } } @@ -57,7 +57,7 @@ impl Scalar for PrimitiveScalar { } #[inline] - fn data_type(&self) -> &ArrowDataType { - &self.data_type + fn dtype(&self) -> &ArrowDataType { + &self.dtype } } diff --git a/crates/polars-arrow/src/scalar/struct_.rs b/crates/polars-arrow/src/scalar/struct_.rs index c3e249a45d47..c9ba6a8e66c0 100644 --- a/crates/polars-arrow/src/scalar/struct_.rs +++ b/crates/polars-arrow/src/scalar/struct_.rs @@ -6,12 +6,12 @@ use crate::datatypes::ArrowDataType; pub struct StructScalar { values: Vec>, is_valid: bool, - data_type: ArrowDataType, + dtype: ArrowDataType, } impl PartialEq for StructScalar { fn eq(&self, other: &Self) -> bool { - (self.data_type == other.data_type) + (self.dtype == other.dtype) && (self.is_valid == other.is_valid) && ((!self.is_valid) | (self.values == other.values)) } @@ -20,12 +20,12 @@ impl PartialEq for StructScalar { impl StructScalar { /// Returns a new [`StructScalar`] #[inline] - pub fn new(data_type: ArrowDataType, values: Option>>) -> Self { + pub fn new(dtype: ArrowDataType, values: Option>>) -> Self { let is_valid = values.is_some(); Self { values: values.unwrap_or_default(), is_valid, - data_type, + dtype, } } @@ -48,7 +48,7 @@ impl Scalar for StructScalar { } #[inline] - fn data_type(&self) -> &ArrowDataType { - &self.data_type + fn dtype(&self) -> &ArrowDataType { + &self.dtype } } diff --git a/crates/polars-arrow/src/scalar/union.rs b/crates/polars-arrow/src/scalar/union.rs index bf22c0cfede2..95f4ebba6e3e 100644 --- a/crates/polars-arrow/src/scalar/union.rs +++ b/crates/polars-arrow/src/scalar/union.rs @@ -6,17 +6,17 @@ use crate::datatypes::ArrowDataType; pub struct UnionScalar { value: Box, type_: i8, - data_type: ArrowDataType, + dtype: ArrowDataType, } impl UnionScalar { /// Returns a new [`UnionScalar`] #[inline] - pub fn new(data_type: ArrowDataType, type_: i8, value: Box) -> Self { + pub fn new(dtype: ArrowDataType, type_: i8, value: Box) -> Self { Self { value, type_, - data_type, + dtype, } } @@ -45,7 +45,7 @@ impl Scalar for UnionScalar { } #[inline] - fn data_type(&self) -> &ArrowDataType { - &self.data_type + fn dtype(&self) -> &ArrowDataType { + &self.dtype } } diff --git a/crates/polars-arrow/src/scalar/utf8.rs b/crates/polars-arrow/src/scalar/utf8.rs index e31c778631fe..986477d5bb5c 100644 --- a/crates/polars-arrow/src/scalar/utf8.rs +++ b/crates/polars-arrow/src/scalar/utf8.rs @@ -45,7 +45,7 @@ impl Scalar for Utf8Scalar { } #[inline] - fn data_type(&self) -> &ArrowDataType { + fn dtype(&self) -> &ArrowDataType { if O::IS_LARGE { &ArrowDataType::LargeUtf8 } else { diff --git 
a/crates/polars-arrow/src/temporal_conversions.rs b/crates/polars-arrow/src/temporal_conversions.rs index 2299ffb5ed07..b5672f6dd626 100644 --- a/crates/polars-arrow/src/temporal_conversions.rs +++ b/crates/polars-arrow/src/temporal_conversions.rs @@ -3,6 +3,7 @@ use chrono::format::{parse, Parsed, StrftimeItems}; use chrono::{DateTime, Duration, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, TimeDelta}; use polars_error::{polars_err, PolarsResult}; +use polars_utils::pl_str::PlSmallStr; use crate::array::{PrimitiveArray, Utf8ViewArray}; use crate::datatypes::{ArrowDataType, TimeUnit}; @@ -267,6 +268,7 @@ pub fn parse_offset(offset: &str) -> PolarsResult { } /// Parses `value` to `Option` consistent with the Arrow's definition of timestamp with timezone. +/// /// `tz` must be built from `timezone` (either via [`parse_offset`] or `chrono-tz`). /// Returns in scale `tz` of `TimeUnit`. #[inline] @@ -317,7 +319,7 @@ pub fn utf8_to_naive_timestamp_scalar(value: &str, fmt: &str, tu: &TimeUnit) -> fn utf8view_to_timestamp_impl( array: &Utf8ViewArray, fmt: &str, - time_zone: String, + time_zone: PlSmallStr, tz: T, time_unit: TimeUnit, ) -> PrimitiveArray { @@ -338,15 +340,25 @@ pub fn parse_offset_tz(timezone: &str) -> PolarsResult { .map_err(|_| polars_err!(InvalidOperation: "timezone \"{timezone}\" cannot be parsed")) } +/// Get the time unit as a multiple of a second +pub const fn time_unit_multiple(unit: TimeUnit) -> i64 { + match unit { + TimeUnit::Second => 1, + TimeUnit::Millisecond => MILLISECONDS, + TimeUnit::Microsecond => MICROSECONDS, + TimeUnit::Nanosecond => NANOSECONDS, + } +} + #[cfg(feature = "chrono-tz")] #[cfg_attr(docsrs, doc(cfg(feature = "chrono-tz")))] fn chrono_tz_utf_to_timestamp( array: &Utf8ViewArray, fmt: &str, - time_zone: String, + time_zone: PlSmallStr, time_unit: TimeUnit, ) -> PolarsResult> { - let tz = parse_offset_tz(&time_zone)?; + let tz = parse_offset_tz(time_zone.as_str())?; Ok(utf8view_to_timestamp_impl( array, fmt, time_zone, tz, time_unit, )) @@ -356,7 +368,7 @@ fn chrono_tz_utf_to_timestamp( fn chrono_tz_utf_to_timestamp( _: &Utf8ViewArray, _: &str, - timezone: String, + timezone: PlSmallStr, _: TimeUnit, ) -> PolarsResult> { panic!("timezone \"{timezone}\" cannot be parsed (feature chrono-tz is not active)") @@ -378,7 +390,7 @@ fn chrono_tz_utf_to_timestamp( pub(crate) fn utf8view_to_timestamp( array: &Utf8ViewArray, fmt: &str, - time_zone: String, + time_zone: PlSmallStr, time_unit: TimeUnit, ) -> PolarsResult> { let tz = parse_offset(time_zone.as_str()); diff --git a/crates/polars-arrow/src/trusted_len.rs b/crates/polars-arrow/src/trusted_len.rs index 5f194770e7c4..359edfd1b88c 100644 --- a/crates/polars-arrow/src/trusted_len.rs +++ b/crates/polars-arrow/src/trusted_len.rs @@ -3,6 +3,7 @@ use std::iter::Scan; use std::slice::Iter; /// An iterator of known, fixed size. +/// /// A trait denoting Rusts' unstable [TrustedLen](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// This is re-defined here and implemented for some iterators until `std::iter::TrustedLen` /// is stabilized. @@ -98,6 +99,14 @@ where } } +impl TrustMyLength>, J> { + /// Create a new `TrustMyLength` iterator that repeats `value` `len` times. + pub fn new_repeat_n(value: J, len: usize) -> Self { + // SAFETY: This is always safe since repeat(..).take(n) always repeats exactly `n` times`. 
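Looping back to the `time_unit_multiple` helper added in the temporal_conversions hunk above: it returns the number of ticks per second for a `TimeUnit`, so scaling whole seconds into a given unit is a single multiplication. A trivial sketch with plain integers, assuming the usual values of 1, 1_000, 1_000_000 and 1_000_000_000 for s/ms/us/ns:

fn seconds_to_ticks(seconds: i64, ticks_per_second: i64) -> i64 {
    // e.g. 3 seconds at millisecond resolution: 3 * 1_000 = 3_000 ticks
    seconds * ticks_per_second
}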
+ unsafe { Self::new(std::iter::repeat(value).take(len), len) } + } +} + impl Iterator for TrustMyLength where I: Iterator, diff --git a/crates/polars-arrow/src/types/bit_chunk.rs b/crates/polars-arrow/src/types/bit_chunk.rs index c618c5458515..be4445a5d77a 100644 --- a/crates/polars-arrow/src/types/bit_chunk.rs +++ b/crates/polars-arrow/src/types/bit_chunk.rs @@ -48,8 +48,10 @@ bit_chunk!(u16); bit_chunk!(u32); bit_chunk!(u64); -/// An [`Iterator`] over a [`BitChunk`]. This iterator is often -/// compiled to SIMD. +/// An [`Iterator`] over a [`BitChunk`]. +/// +/// This iterator is often compiled to SIMD. +/// /// The [LSB](https://en.wikipedia.org/wiki/Bit_numbering#Least_significant_bit) corresponds /// to the first slot, as defined by the arrow specification. /// # Example diff --git a/crates/polars-compute/src/arithmetic/signed.rs b/crates/polars-compute/src/arithmetic/signed.rs index a19f6b231526..c77057b78a47 100644 --- a/crates/polars-compute/src/arithmetic/signed.rs +++ b/crates/polars-compute/src/arithmetic/signed.rs @@ -106,7 +106,7 @@ macro_rules! impl_signed_arith_kernel { fn prim_wrapping_floor_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { if rhs == 0 { - PArr::full_null(lhs.len(), lhs.data_type().clone()) + PArr::full_null(lhs.len(), lhs.dtype().clone()) } else if rhs == -1 { Self::prim_wrapping_neg(lhs) } else if rhs == 1 { @@ -145,7 +145,7 @@ macro_rules! impl_signed_arith_kernel { fn prim_wrapping_trunc_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { if rhs == 0 { - PArr::full_null(lhs.len(), lhs.data_type().clone()) + PArr::full_null(lhs.len(), lhs.dtype().clone()) } else if rhs == -1 { Self::prim_wrapping_neg(lhs) } else if rhs == 1 { @@ -177,7 +177,7 @@ macro_rules! impl_signed_arith_kernel { fn prim_wrapping_mod_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { if rhs == 0 { - PArr::full_null(lhs.len(), lhs.data_type().clone()) + PArr::full_null(lhs.len(), lhs.dtype().clone()) } else if rhs == -1 || rhs == 1 { lhs.fill_with(0) } else { diff --git a/crates/polars-compute/src/arithmetic/unsigned.rs b/crates/polars-compute/src/arithmetic/unsigned.rs index 2ae40332e820..db71590989bd 100644 --- a/crates/polars-compute/src/arithmetic/unsigned.rs +++ b/crates/polars-compute/src/arithmetic/unsigned.rs @@ -85,7 +85,7 @@ macro_rules! impl_unsigned_arith_kernel { fn prim_wrapping_floor_div_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { if rhs == 0 { - PArr::full_null(lhs.len(), lhs.data_type().clone()) + PArr::full_null(lhs.len(), lhs.dtype().clone()) } else if rhs == 1 { lhs } else { @@ -115,7 +115,7 @@ macro_rules! impl_unsigned_arith_kernel { fn prim_wrapping_mod_scalar(lhs: PArr<$T>, rhs: $T) -> PArr<$T> { if rhs == 0 { - PArr::full_null(lhs.len(), lhs.data_type().clone()) + PArr::full_null(lhs.len(), lhs.dtype().clone()) } else if rhs == 1 { lhs.fill_with(0) } else { diff --git a/crates/polars-compute/src/comparisons/array.rs b/crates/polars-compute/src/comparisons/array.rs index b981a50b3547..da120f27553b 100644 --- a/crates/polars-compute/src/comparisons/array.rs +++ b/crates/polars-compute/src/comparisons/array.rs @@ -34,17 +34,15 @@ impl TotalEqKernel for FixedSizeListArray { // make any sense. 
assert_eq!(self.len(), other.len()); - let ArrowDataType::FixedSizeList(self_type, self_width) = - self.data_type().to_logical_type() + let ArrowDataType::FixedSizeList(self_type, self_width) = self.dtype().to_logical_type() else { panic!("array comparison called with non-array type"); }; - let ArrowDataType::FixedSizeList(other_type, other_width) = - other.data_type().to_logical_type() + let ArrowDataType::FixedSizeList(other_type, other_width) = other.dtype().to_logical_type() else { panic!("array comparison called with non-array type"); }; - assert_eq!(self_type.data_type(), other_type.data_type()); + assert_eq!(self_type.dtype(), other_type.dtype()); if self_width != other_width { return Bitmap::new_with_value(false, self.len()); @@ -57,17 +55,15 @@ impl TotalEqKernel for FixedSizeListArray { fn tot_ne_kernel(&self, other: &Self) -> Bitmap { assert_eq!(self.len(), other.len()); - let ArrowDataType::FixedSizeList(self_type, self_width) = - self.data_type().to_logical_type() + let ArrowDataType::FixedSizeList(self_type, self_width) = self.dtype().to_logical_type() else { panic!("array comparison called with non-array type"); }; - let ArrowDataType::FixedSizeList(other_type, other_width) = - other.data_type().to_logical_type() + let ArrowDataType::FixedSizeList(other_type, other_width) = other.dtype().to_logical_type() else { panic!("array comparison called with non-array type"); }; - assert_eq!(self_type.data_type(), other_type.data_type()); + assert_eq!(self_type.dtype(), other_type.dtype()); if self_width != other_width { return Bitmap::new_with_value(true, self.len()); diff --git a/crates/polars-compute/src/comparisons/dyn_array.rs b/crates/polars-compute/src/comparisons/dyn_array.rs index 693293f4e2c5..3ee3d802f09f 100644 --- a/crates/polars-compute/src/comparisons/dyn_array.rs +++ b/crates/polars-compute/src/comparisons/dyn_array.rs @@ -20,10 +20,10 @@ macro_rules! compare { let lhs = $lhs; let rhs = $rhs; - assert_eq!(lhs.data_type(), rhs.data_type()); + assert_eq!(lhs.dtype(), rhs.dtype()); use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR}; - match lhs.data_type().to_physical_type() { + match lhs.dtype().to_physical_type() { PH::Boolean => call_binary!(BooleanArray, lhs, rhs, $op), PH::BinaryView => call_binary!(BinaryViewArray, lhs, rhs, $op), PH::Utf8View => call_binary!(Utf8ViewArray, lhs, rhs, $op), diff --git a/crates/polars-compute/src/filter/mod.rs b/crates/polars-compute/src/filter/mod.rs index 2ac66243fb8e..6de1afbab2ed 100644 --- a/crates/polars-compute/src/filter/mod.rs +++ b/crates/polars-compute/src/filter/mod.rs @@ -28,14 +28,14 @@ pub fn filter_with_bitmap(array: &dyn Array, mask: &Bitmap) -> Box { // Fast-path: completely empty or completely full mask. 
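The fast path flagged in the comment above skips the per-value gather entirely when the mask is uniform; over a plain bool mask the shape of the check is roughly the following (a sketch, not the arrow `Bitmap` API):

fn filter_fast_path<T: Clone>(values: &[T], mask: &[bool]) -> Option<Vec<T>> {
    let false_count = mask.iter().filter(|keep| !**keep).count();
    if false_count == mask.len() {
        Some(Vec::new()) // everything filtered out: empty result
    } else if false_count == 0 {
        Some(values.to_vec()) // nothing filtered out: keep the input as-is
    } else {
        None // mixed mask: fall through to the general filter kernels
    }
}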
let false_count = mask.unset_bits(); if false_count == mask.len() { - return new_empty_array(array.data_type().clone()); + return new_empty_array(array.dtype().clone()); } if false_count == 0 { return array.to_boxed(); } use arrow::datatypes::PhysicalType::*; - match array.data_type().to_physical_type() { + match array.dtype().to_physical_type() { Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { let array: &PrimitiveArray<$T> = array.as_any().downcast_ref().unwrap(); let (values, validity) = primitive::filter_values_and_validity::<$T>(array.values(), array.validity(), mask); @@ -45,7 +45,7 @@ pub fn filter_with_bitmap(array: &dyn Array, mask: &Bitmap) -> Box { let array = array.as_any().downcast_ref::().unwrap(); let (values, validity) = boolean::filter_bitmap_and_validity(array.values(), array.validity(), mask); - BooleanArray::new(array.data_type().clone(), values, validity).boxed() + BooleanArray::new(array.dtype().clone(), values, validity).boxed() }, BinaryView => { let array = array.as_any().downcast_ref::().unwrap(); @@ -54,7 +54,7 @@ pub fn filter_with_bitmap(array: &dyn Array, mask: &Bitmap) -> Box { let (views, validity) = primitive::filter_values_and_validity(views, validity, mask); unsafe { BinaryViewArray::new_unchecked_unknown_md( - array.data_type().clone(), + array.dtype().clone(), views.into(), array.data_buffers().clone(), validity, diff --git a/crates/polars-compute/src/if_then_else/array.rs b/crates/polars-compute/src/if_then_else/array.rs index a15349bf1e2c..67f9b450ec7c 100644 --- a/crates/polars-compute/src/if_then_else/array.rs +++ b/crates/polars-compute/src/if_then_else/array.rs @@ -26,7 +26,7 @@ impl IfThenElseKernel for FixedSizeListArray { if_false: &Self, ) -> Self { let if_true_list: FixedSizeListArray = - std::iter::once(if_true).collect_arr_trusted_with_dtype(if_false.data_type().clone()); + std::iter::once(if_true).collect_arr_trusted_with_dtype(if_false.dtype().clone()); let mut growable = GrowableFixedSizeList::new(vec![&if_true_list, if_false], false, mask.len()); unsafe { @@ -46,7 +46,7 @@ impl IfThenElseKernel for FixedSizeListArray { if_false: Self::Scalar<'_>, ) -> Self { let if_false_list: FixedSizeListArray = - std::iter::once(if_false).collect_arr_trusted_with_dtype(if_true.data_type().clone()); + std::iter::once(if_false).collect_arr_trusted_with_dtype(if_true.dtype().clone()); let mut growable = GrowableFixedSizeList::new(vec![if_true, &if_false_list], false, mask.len()); unsafe { diff --git a/crates/polars-compute/src/if_then_else/list.rs b/crates/polars-compute/src/if_then_else/list.rs index aa3096c6f07e..284d6b7f0420 100644 --- a/crates/polars-compute/src/if_then_else/list.rs +++ b/crates/polars-compute/src/if_then_else/list.rs @@ -26,7 +26,7 @@ impl IfThenElseKernel for ListArray { if_false: &Self, ) -> Self { let if_true_list: ListArray = - std::iter::once(if_true).collect_arr_trusted_with_dtype(if_false.data_type().clone()); + std::iter::once(if_true).collect_arr_trusted_with_dtype(if_false.dtype().clone()); let mut growable = GrowableList::new(vec![&if_true_list, if_false], false, mask.len()); unsafe { if_then_else_extend( @@ -45,7 +45,7 @@ impl IfThenElseKernel for ListArray { if_false: Self::Scalar<'_>, ) -> Self { let if_false_list: ListArray = - std::iter::once(if_false).collect_arr_trusted_with_dtype(if_true.data_type().clone()); + std::iter::once(if_false).collect_arr_trusted_with_dtype(if_true.dtype().clone()); let mut growable = GrowableList::new(vec![if_true, &if_false_list], false, mask.len()); unsafe { 
if_then_else_extend( diff --git a/crates/polars-compute/src/if_then_else/mod.rs b/crates/polars-compute/src/if_then_else/mod.rs index c6c752483330..8265422fb9de 100644 --- a/crates/polars-compute/src/if_then_else/mod.rs +++ b/crates/polars-compute/src/if_then_else/mod.rs @@ -100,7 +100,7 @@ impl IfThenElseKernel for PrimitiveArray { } } -fn if_then_else_validity( +pub fn if_then_else_validity( mask: &Bitmap, if_true: Option<&Bitmap>, if_false: Option<&Bitmap>, diff --git a/crates/polars-compute/src/if_then_else/view.rs b/crates/polars-compute/src/if_then_else/view.rs index 25ff67f401aa..5b3fd8fc4df9 100644 --- a/crates/polars-compute/src/if_then_else/view.rs +++ b/crates/polars-compute/src/if_then_else/view.rs @@ -6,6 +6,7 @@ use arrow::array::{Array, BinaryViewArray, MutablePlBinary, Utf8ViewArray, View} use arrow::bitmap::Bitmap; use arrow::buffer::Buffer; use arrow::datatypes::ArrowDataType; +use polars_utils::aliases::{InitHashMaps, PlHashSet}; use super::IfThenElseKernel; use crate::if_then_else::scalar::if_then_else_broadcast_both_scalar_64; @@ -28,12 +29,25 @@ fn make_buffer_and_views( (views, buf) } +fn has_duplicate_buffers(bufs: &[Buffer]) -> bool { + let mut has_duplicate_buffers = false; + let mut bufset = PlHashSet::new(); + for buf in bufs { + if !bufset.insert(buf.as_ptr()) { + has_duplicate_buffers = true; + break; + } + } + has_duplicate_buffers +} + impl IfThenElseKernel for BinaryViewArray { type Scalar<'a> = &'a [u8]; fn if_then_else(mask: &Bitmap, if_true: &Self, if_false: &Self) -> Self { let combined_buffers: Arc<_>; let false_buffer_idx_offset: u32; + let mut has_duplicate_bufs = false; if Arc::ptr_eq(if_true.data_buffers(), if_false.data_buffers()) { // Share exact same buffers, no need to combine. combined_buffers = if_true.data_buffers().clone(); @@ -42,7 +56,9 @@ impl IfThenElseKernel for BinaryViewArray { // Put false buffers after true buffers. 
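The `has_duplicate_buffers` helper added above compares buffers by pointer identity: once the true and false buffer lists are chained (as the comment above describes), the same allocation can show up twice, and the builder must then take the slower view-deduplicating path instead of the trusted-length one. A minimal sketch of the pointer check, using byte slices as stand-ins for `Buffer`:

use std::collections::HashSet;

fn has_duplicate_ptrs(bufs: &[&[u8]]) -> bool {
    // Two entries count as the same buffer when they point at the same allocation.
    let mut seen = HashSet::new();
    bufs.iter().any(|buf| !seen.insert(buf.as_ptr()))
}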
let true_buffers = if_true.data_buffers().iter().cloned(); let false_buffers = if_false.data_buffers().iter().cloned(); + combined_buffers = true_buffers.chain(false_buffers).collect(); + has_duplicate_bufs = has_duplicate_buffers(&combined_buffers); false_buffer_idx_offset = if_true.data_buffers().len() as u32; } @@ -57,14 +73,21 @@ impl IfThenElseKernel for BinaryViewArray { let validity = super::if_then_else_validity(mask, if_true.validity(), if_false.validity()); let mut builder = MutablePlBinary::with_capacity(views.len()); - unsafe { - builder.extend_non_null_views_trusted_len_unchecked( - views.into_iter(), - combined_buffers.deref(), - ) - }; + + if has_duplicate_bufs { + unsafe { + builder.extend_non_null_views_unchecked_dedupe( + views.into_iter(), + combined_buffers.deref(), + ) + }; + } else { + unsafe { + builder.extend_non_null_views_unchecked(views.into_iter(), combined_buffers.deref()) + }; + } builder - .freeze_with_dtype(if_true.data_type().clone()) + .freeze_with_dtype(if_true.dtype().clone()) .with_validity(validity) } @@ -90,14 +113,19 @@ impl IfThenElseKernel for BinaryViewArray { let validity = super::if_then_else_validity(mask, None, if_false.validity()); let mut builder = MutablePlBinary::with_capacity(views.len()); + unsafe { - builder.extend_non_null_views_trusted_len_unchecked( - views.into_iter(), - combined_buffers.deref(), - ) - }; + if has_duplicate_buffers(&combined_buffers) { + builder.extend_non_null_views_unchecked_dedupe( + views.into_iter(), + combined_buffers.deref(), + ) + } else { + builder.extend_non_null_views_unchecked(views.into_iter(), combined_buffers.deref()) + } + } builder - .freeze_with_dtype(if_false.data_type().clone()) + .freeze_with_dtype(if_false.dtype().clone()) .with_validity(validity) } @@ -125,13 +153,17 @@ impl IfThenElseKernel for BinaryViewArray { let mut builder = MutablePlBinary::with_capacity(views.len()); unsafe { - builder.extend_non_null_views_trusted_len_unchecked( - views.into_iter(), - combined_buffers.deref(), - ) + if has_duplicate_buffers(&combined_buffers) { + builder.extend_non_null_views_unchecked_dedupe( + views.into_iter(), + combined_buffers.deref(), + ) + } else { + builder.extend_non_null_views_unchecked(views.into_iter(), combined_buffers.deref()) + } }; builder - .freeze_with_dtype(if_true.data_type().clone()) + .freeze_with_dtype(if_true.dtype().clone()) .with_validity(validity) } @@ -152,7 +184,11 @@ impl IfThenElseKernel for BinaryViewArray { let mut builder = MutablePlBinary::with_capacity(views.len()); unsafe { - builder.extend_non_null_views_trusted_len_unchecked(views.into_iter(), buffers.deref()) + if has_duplicate_buffers(&buffers) { + builder.extend_non_null_views_unchecked_dedupe(views.into_iter(), buffers.deref()) + } else { + builder.extend_non_null_views_unchecked(views.into_iter(), buffers.deref()) + } }; builder.freeze_with_dtype(dtype) } diff --git a/crates/polars-compute/src/min_max/dyn_array.rs b/crates/polars-compute/src/min_max/dyn_array.rs index e988bbd0ef54..20af38dedc63 100644 --- a/crates/polars-compute/src/min_max/dyn_array.rs +++ b/crates/polars-compute/src/min_max/dyn_array.rs @@ -12,8 +12,7 @@ macro_rules! 
call_op { }}; (dt: $T:ty, $scalar:ty, $arr:expr, $op:path) => {{ let arr: &$T = $arr.as_any().downcast_ref().unwrap(); - $op(arr) - .map(|v| Box::new(<$scalar>::new(arr.data_type().clone(), Some(v))) as Box) + $op(arr).map(|v| Box::new(<$scalar>::new(arr.dtype().clone(), Some(v))) as Box) }}; ($T:ty, $scalar:ty, $arr:expr, $op:path, ret_two) => {{ let arr: &$T = $arr.as_any().downcast_ref().unwrap(); @@ -28,8 +27,8 @@ macro_rules! call_op { let arr: &$T = $arr.as_any().downcast_ref().unwrap(); $op(arr).map(|(l, r)| { ( - Box::new(<$scalar>::new(arr.data_type().clone(), Some(l))) as Box, - Box::new(<$scalar>::new(arr.data_type().clone(), Some(r))) as Box, + Box::new(<$scalar>::new(arr.dtype().clone(), Some(l))) as Box, + Box::new(<$scalar>::new(arr.dtype().clone(), Some(r))) as Box, ) }) }}; @@ -42,7 +41,7 @@ macro_rules! call { use arrow::datatypes::{PhysicalType as PH, PrimitiveType as PR}; use PrimitiveArray as PArr; use PrimitiveScalar as PScalar; - match arr.data_type().to_physical_type() { + match arr.dtype().to_physical_type() { PH::Boolean => call_op!(BooleanArray, BooleanScalar, arr, $op$(, $variant)?), PH::Primitive(PR::Int8) => call_op!(dt: PArr, PScalar, arr, $op$(, $variant)?), PH::Primitive(PR::Int16) => call_op!(dt: PArr, PScalar, arr, $op$(, $variant)?), @@ -65,7 +64,7 @@ macro_rules! call { PH::Utf8 => call_op!(Utf8Array, BinaryScalar, arr, $op$(, $variant)?), PH::LargeUtf8 => call_op!(Utf8Array, BinaryScalar, arr, $op$(, $variant)?), - _ => todo!("Dynamic MinMax is not yet implemented for {:?}", arr.data_type()), + _ => todo!("Dynamic MinMax is not yet implemented for {:?}", arr.dtype()), } }}; } diff --git a/crates/polars-compute/src/unique/boolean.rs b/crates/polars-compute/src/unique/boolean.rs index 511a45bcea00..bee639d2bacf 100644 --- a/crates/polars-compute/src/unique/boolean.rs +++ b/crates/polars-compute/src/unique/boolean.rs @@ -7,7 +7,7 @@ use super::{GenericUniqueKernel, RangedUniqueKernel}; pub struct BooleanUniqueKernelState { seen: u32, has_null: bool, - data_type: ArrowDataType, + dtype: ArrowDataType, } const fn to_value(scalar: Option) -> u8 { @@ -19,11 +19,11 @@ const fn to_value(scalar: Option) -> u8 { } impl BooleanUniqueKernelState { - pub fn new(has_null: bool, data_type: ArrowDataType) -> Self { + pub fn new(has_null: bool, dtype: ArrowDataType) -> Self { Self { seen: 0, has_null, - data_type, + dtype, } } @@ -91,7 +91,7 @@ impl RangedUniqueKernel for BooleanUniqueKernelState { let values = values.freeze(); - BooleanArray::new(self.data_type, values, validity) + BooleanArray::new(self.dtype, values, validity) } fn finalize_n_unique(self) -> usize { @@ -105,22 +105,19 @@ impl RangedUniqueKernel for BooleanUniqueKernelState { impl GenericUniqueKernel for BooleanArray { fn unique(&self) -> Self { - let mut state = - BooleanUniqueKernelState::new(self.null_count() > 0, self.data_type().clone()); + let mut state = BooleanUniqueKernelState::new(self.null_count() > 0, self.dtype().clone()); state.append(self); state.finalize_unique() } fn n_unique(&self) -> usize { - let mut state = - BooleanUniqueKernelState::new(self.null_count() > 0, self.data_type().clone()); + let mut state = BooleanUniqueKernelState::new(self.null_count() > 0, self.dtype().clone()); state.append(self); state.finalize_n_unique() } fn n_unique_non_null(&self) -> usize { - let mut state = - BooleanUniqueKernelState::new(self.null_count() > 0, self.data_type().clone()); + let mut state = BooleanUniqueKernelState::new(self.null_count() > 0, self.dtype().clone()); state.append(self); 
state.finalize_n_unique_non_null() } diff --git a/crates/polars-compute/src/unique/primitive.rs b/crates/polars-compute/src/unique/primitive.rs index 9a1e4ff933bb..c1e258f800fa 100644 --- a/crates/polars-compute/src/unique/primitive.rs +++ b/crates/polars-compute/src/unique/primitive.rs @@ -16,19 +16,14 @@ pub struct PrimitiveRangedUniqueState { seen: u128, range: RangeInclusive, has_null: bool, - data_type: ArrowDataType, + dtype: ArrowDataType, } impl PrimitiveRangedUniqueState where T: Add + Sub + FromPrimitive + IsFloat, { - pub fn new( - min_value: T, - max_value: T, - has_null: bool, - data_type: ArrowDataType, - ) -> Option { + pub fn new(min_value: T, max_value: T, has_null: bool, dtype: ArrowDataType) -> Option { // We cannot really do this for floating point number as these are not as discrete as // integers. if T::is_float() { @@ -46,7 +41,7 @@ where seen: 0, range: min_value..=max_value, has_null, - data_type, + dtype, }) } @@ -163,7 +158,7 @@ where (values, None) }; - PrimitiveArray::new(self.data_type, values.into(), validity) + PrimitiveArray::new(self.dtype, values.into(), validity) } fn finalize_n_unique(self) -> usize { diff --git a/crates/polars-core/Cargo.toml b/crates/polars-core/Cargo.toml index 3fa4152dd877..f6309d07d8c3 100644 --- a/crates/polars-core/Cargo.toml +++ b/crates/polars-core/Cargo.toml @@ -12,6 +12,7 @@ description = "Core of the Polars DataFrame library" polars-compute = { workspace = true } polars-error = { workspace = true } polars-row = { workspace = true } +polars-schema = { workspace = true } polars-utils = { workspace = true } ahash = { workspace = true } @@ -36,7 +37,6 @@ scopeguard = { workspace = true } # activate if you want serde support for Series and DataFrames serde = { workspace = true, optional = true } serde_json = { workspace = true, optional = true } -smartstring = { workspace = true } thiserror = { workspace = true } xxhash-rust = { workspace = true } @@ -65,7 +65,7 @@ performant = ["arrow/performant", "reinterpret"] # extra utilities for StringChunked strings = ["regex", "arrow/strings", "polars-error/regex"] # support for ObjectChunked (downcastable Series of any type) -object = ["serde_json"] +object = ["serde_json", "algorithm_group_by"] fmt = ["comfy-table/tty"] fmt_no_tty = ["comfy-table"] @@ -94,9 +94,9 @@ diagonal_concat = [] dataframe_arithmetic = [] product = [] unique_counts = [] -partition_by = [] +partition_by = ["algorithm_group_by"] describe = [] -timezones = ["chrono-tz", "arrow/chrono-tz", "arrow/timezones"] +timezones = ["temporal", "chrono", "chrono-tz", "arrow/chrono-tz", "arrow/timezones"] dynamic_group_by = ["dtype-datetime", "dtype-date"] arrow_rs = ["arrow-array", "arrow/arrow_rs"] @@ -118,8 +118,8 @@ dtype-struct = [] bigidx = ["arrow/bigidx", "polars-utils/bigidx"] python = [] -serde = ["dep:serde", "smartstring/serde", "bitflags/serde"] -serde-lazy = ["serde", "arrow/serde", "indexmap/serde", "smartstring/serde", "chrono/serde"] +serde = ["dep:serde", "bitflags/serde", "polars-schema/serde"] +serde-lazy = ["serde", "arrow/serde", "indexmap/serde", "chrono/serde"] docs-selection = [ "ndarray", diff --git a/crates/polars-core/src/chunked_array/arithmetic/mod.rs b/crates/polars-core/src/chunked_array/arithmetic/mod.rs index ab4b3e7eb337..e45c12ef12f1 100644 --- a/crates/polars-core/src/chunked_array/arithmetic/mod.rs +++ b/crates/polars-core/src/chunked_array/arithmetic/mod.rs @@ -76,7 +76,7 @@ impl Add for &BinaryChunked { unsafe { std::mem::transmute::<_, &'static [u8]>(out) } }) }, - None => 
BinaryChunked::full_null(self.name(), self.len()), + None => BinaryChunked::full_null(self.name().clone(), self.len()), }; } // broadcasting path lhs @@ -91,7 +91,7 @@ impl Add for &BinaryChunked { // ref is valid for the lifetime of this closure. unsafe { std::mem::transmute::<_, &'static [u8]>(out) } }), - None => BinaryChunked::full_null(self.name(), rhs.len()), + None => BinaryChunked::full_null(self.name().clone(), rhs.len()), }; } @@ -137,7 +137,7 @@ impl Add for &BooleanChunked { let rhs = rhs.get(0); return match rhs { Some(rhs) => unary_elementwise_values(self, |v| v as IdxSize + rhs as IdxSize), - None => IdxCa::full_null(self.name(), self.len()), + None => IdxCa::full_null(self.name().clone(), self.len()), }; } // Broadcasting path lhs. @@ -161,10 +161,10 @@ pub(crate) mod test { use crate::prelude::*; pub(crate) fn create_two_chunked() -> (Int32Chunked, Int32Chunked) { - let mut a1 = Int32Chunked::new("a", &[1, 2, 3]); - let a2 = Int32Chunked::new("a", &[4, 5, 6]); - let a3 = Int32Chunked::new("a", &[1, 2, 3, 4, 5, 6]); - a1.append(&a2); + let mut a1 = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3]); + let a2 = Int32Chunked::new(PlSmallStr::from_static("a"), &[4, 5, 6]); + let a3 = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3, 4, 5, 6]); + a1.append(&a2).unwrap(); (a1, a3) } diff --git a/crates/polars-core/src/chunked_array/array/iterator.rs b/crates/polars-core/src/chunked_array/array/iterator.rs index 52785f1208db..599de4ca3ec4 100644 --- a/crates/polars-core/src/chunked_array/array/iterator.rs +++ b/crates/polars-core/src/chunked_array/array/iterator.rs @@ -26,7 +26,7 @@ impl ArrayChunked { /// The lifetime of [AmortSeries] is bound to the iterator. Keeping it alive /// longer than the iterator is UB. pub fn amortized_iter(&self) -> AmortizedListIter> + '_> { - self.amortized_iter_with_name("") + self.amortized_iter_with_name(PlSmallStr::EMPTY) } /// This is an iterator over a [`ArrayChunked`] that save allocations. @@ -44,7 +44,7 @@ impl ArrayChunked { /// will be set. pub fn amortized_iter_with_name( &self, - name: &str, + name: PlSmallStr, ) -> AmortizedListIter> + '_> { // we create the series container from the inner array // so that the container has the proper dtype. @@ -84,7 +84,7 @@ impl ArrayChunked { { if self.is_empty() { return Ok(Series::new_empty( - self.name(), + self.name().clone(), &DataType::List(Box::new(self.inner_dtype().clone())), ) .list() @@ -109,7 +109,7 @@ impl ArrayChunked { }) .collect::>()? }; - ca.rename(self.name()); + ca.rename(self.name().clone()); if fast_explode { ca.set_fast_explode(); } @@ -135,7 +135,7 @@ impl ArrayChunked { to_arr(&out) }) }) - .collect_ca_with_dtype(self.name(), self.dtype().clone()) + .collect_ca_with_dtype(self.name().clone(), self.dtype().clone()) } /// Try apply a closure `F` to each array. @@ -158,7 +158,7 @@ impl ArrayChunked { }) .transpose() }) - .try_collect_ca_with_dtype(self.name(), self.dtype().clone()) + .try_collect_ca_with_dtype(self.name().clone(), self.dtype().clone()) } /// Zip with a `ChunkedArray` then apply a binary function `F` elementwise. @@ -184,7 +184,7 @@ impl ArrayChunked { let out = f(opt_s, opt_v); out.map(|s| to_arr(&s)) }) - .collect_ca_with_dtype(self.name(), self.dtype().clone()) + .collect_ca_with_dtype(self.name().clone(), self.dtype().clone()) } /// Apply a closure `F` elementwise. 
@@ -196,7 +196,7 @@ impl ArrayChunked { V::Array: ArrayFromIter>, { { - self.amortized_iter().map(f).collect_ca(self.name()) + self.amortized_iter().map(f).collect_ca(self.name().clone()) } } @@ -208,7 +208,9 @@ impl ArrayChunked { V::Array: ArrayFromIter>, { { - self.amortized_iter().map(f).try_collect_ca(self.name()) + self.amortized_iter() + .map(f) + .try_collect_ca(self.name().clone()) } } diff --git a/crates/polars-core/src/chunked_array/array/mod.rs b/crates/polars-core/src/chunked_array/array/mod.rs index a3b7a1a1f339..59bdd92b67cc 100644 --- a/crates/polars-core/src/chunked_array/array/mod.rs +++ b/crates/polars-core/src/chunked_array/array/mod.rs @@ -34,7 +34,9 @@ impl ArrayChunked { let chunks: Vec<_> = self.downcast_iter().map(|c| c.values().clone()).collect(); // SAFETY: Data type of arrays matches because they are chunks from the same array. - unsafe { Series::from_chunks_and_dtype_unchecked(self.name(), chunks, self.inner_dtype()) } + unsafe { + Series::from_chunks_and_dtype_unchecked(self.name().clone(), chunks, self.inner_dtype()) + } } /// Ignore the list indices and apply `func` to the inner type as [`Series`]. @@ -46,14 +48,14 @@ impl ArrayChunked { let ca = self.rechunk(); let field = self .inner_dtype() - .to_arrow_field("item", CompatLevel::newest()); + .to_arrow_field(PlSmallStr::from_static("item"), CompatLevel::newest()); let chunks = ca.downcast_iter().map(|arr| { let elements = unsafe { Series::_try_from_arrow_unchecked_with_md( - self.name(), + self.name().clone(), vec![(*arr.values()).clone()], - &field.data_type, + &field.dtype, Some(&field.metadata), ) .unwrap() @@ -76,6 +78,6 @@ impl ArrayChunked { Ok(arr) }); - ArrayChunked::try_from_chunk_iter(self.name(), chunks) + ArrayChunked::try_from_chunk_iter(self.name().clone(), chunks) } } diff --git a/crates/polars-core/src/chunked_array/binary.rs b/crates/polars-core/src/chunked_array/binary.rs index 30b492174ec7..27ab78d137a5 100644 --- a/crates/polars-core/src/chunked_array/binary.rs +++ b/crates/polars-core/src/chunked_array/binary.rs @@ -1,4 +1,4 @@ -use ahash::RandomState; +use polars_utils::aliases::PlRandomState; use polars_utils::hashing::BytesHash; use rayon::prelude::*; @@ -8,7 +8,11 @@ use crate::utils::{_set_partition_size, _split_offsets}; use crate::POOL; #[inline] -fn fill_bytes_hashes<'a, T>(ca: &'a ChunkedArray, null_h: u64, hb: RandomState) -> Vec +fn fill_bytes_hashes<'a, T>( + ca: &'a ChunkedArray, + null_h: u64, + hb: PlRandomState, +) -> Vec where T: PolarsDataType, <::Array as StaticArray>::ValueT<'a>: AsRef<[u8]>, @@ -39,7 +43,7 @@ where pub fn to_bytes_hashes<'a>( &'a self, mut multithreaded: bool, - hb: RandomState, + hb: PlRandomState, ) -> Vec>> { multithreaded &= POOL.current_num_threads() > 1; let null_h = get_null_hash_value(&hb); diff --git a/crates/polars-core/src/chunked_array/bitwise.rs b/crates/polars-core/src/chunked_array/bitwise.rs index 9e8fc482498c..ad179e0c1c28 100644 --- a/crates/polars-core/src/chunked_array/bitwise.rs +++ b/crates/polars-core/src/chunked_array/bitwise.rs @@ -71,10 +71,10 @@ impl BitOr for &BooleanChunked { (1, 1) => {}, (1, _) => { return match self.get(0) { - Some(true) => BooleanChunked::full(self.name(), true, rhs.len()), + Some(true) => BooleanChunked::full(self.name().clone(), true, rhs.len()), Some(false) => { let mut rhs = rhs.clone(); - rhs.rename(self.name()); + rhs.rename(self.name().clone()); rhs }, None => &self.new_from_index(0, rhs.len()) | rhs, @@ -82,9 +82,9 @@ impl BitOr for &BooleanChunked { }, (_, 1) => { return match rhs.get(0) { 
- Some(true) => BooleanChunked::full(self.name(), true, self.len()), + Some(true) => BooleanChunked::full(self.name().clone(), true, self.len()), Some(false) => self.clone(), - None => &rhs.new_from_index(0, self.len()) | self, + None => self | &rhs.new_from_index(0, self.len()), }; }, _ => {}, @@ -114,12 +114,12 @@ impl BitXor for &BooleanChunked { return match self.get(0) { Some(true) => { let mut rhs = rhs.not(); - rhs.rename(self.name()); + rhs.rename(self.name().clone()); rhs }, Some(false) => { let mut rhs = rhs.clone(); - rhs.rename(self.name()); + rhs.rename(self.name().clone()); rhs }, None => &self.new_from_index(0, rhs.len()) | rhs, @@ -129,7 +129,7 @@ impl BitXor for &BooleanChunked { return match rhs.get(0) { Some(true) => self.not(), Some(false) => self.clone(), - None => &rhs.new_from_index(0, self.len()) | self, + None => self | &rhs.new_from_index(0, self.len()), }; }, _ => {}, @@ -161,15 +161,15 @@ impl BitAnd for &BooleanChunked { (1, 1) => {}, (1, _) => { return match self.get(0) { - Some(true) => rhs.clone().with_name(self.name()), - Some(false) => BooleanChunked::full(self.name(), false, rhs.len()), + Some(true) => rhs.clone().with_name(self.name().clone()), + Some(false) => BooleanChunked::full(self.name().clone(), false, rhs.len()), None => &self.new_from_index(0, rhs.len()) & rhs, }; }, (_, 1) => { return match rhs.get(0) { Some(true) => self.clone(), - Some(false) => BooleanChunked::full(self.name(), false, self.len()), + Some(false) => BooleanChunked::full(self.name().clone(), false, self.len()), None => self & &rhs.new_from_index(0, self.len()), }; }, @@ -195,8 +195,8 @@ mod test { #[test] fn guard_so_issue_2494() { // this cause a stack overflow - let a = BooleanChunked::new("a", [None]); - let b = BooleanChunked::new("b", [None]); + let a = BooleanChunked::new(PlSmallStr::from_static("a"), [None]); + let b = BooleanChunked::new(PlSmallStr::from_static("b"), [None]); assert_eq!((&a).bitand(&b).null_count(), 1); assert_eq!((&a).bitor(&b).null_count(), 1); diff --git a/crates/polars-core/src/chunked_array/builder/boolean.rs b/crates/polars-core/src/chunked_array/builder/boolean.rs index 031a45c8d74f..649db3d6252e 100644 --- a/crates/polars-core/src/chunked_array/builder/boolean.rs +++ b/crates/polars-core/src/chunked_array/builder/boolean.rs @@ -30,7 +30,7 @@ impl ChunkedBuilder for BooleanChunkedBuilder { } impl BooleanChunkedBuilder { - pub fn new(name: &str, capacity: usize) -> Self { + pub fn new(name: PlSmallStr, capacity: usize) -> Self { BooleanChunkedBuilder { array_builder: MutableBooleanArray::with_capacity(capacity), field: Field::new(name, DataType::Boolean), diff --git a/crates/polars-core/src/chunked_array/builder/fixed_size_list.rs b/crates/polars-core/src/chunked_array/builder/fixed_size_list.rs index e235d08ffbd6..64cccf3b7f36 100644 --- a/crates/polars-core/src/chunked_array/builder/fixed_size_list.rs +++ b/crates/polars-core/src/chunked_array/builder/fixed_size_list.rs @@ -1,13 +1,13 @@ use arrow::types::NativeType; +use polars_utils::pl_str::PlSmallStr; use polars_utils::unwrap::UnwrapUncheckedRelease; -use smartstring::alias::String as SmartString; use crate::prelude::*; pub(crate) struct FixedSizeListNumericBuilder { inner: Option>>, width: usize, - name: SmartString, + name: PlSmallStr, logical_dtype: DataType, } @@ -16,7 +16,7 @@ impl FixedSizeListNumericBuilder { /// /// The caller must ensure that the physical numerical type match logical type. 
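In the bitwise impls above, the null-scalar branches now compute `self | &rhs.new_from_index(0, self.len())` with `self` on the left, presumably so the result keeps the left-hand column's name, and `full`/`full_null`/`rename`/`with_name` all take the name by value. A small usage sketch of these operators, assuming the usual `polars_core::prelude` re-exports; the column names are illustrative:

```rust
use polars_core::prelude::*;
use polars_utils::pl_str::PlSmallStr;

fn main() {
    let a = BooleanChunked::new(
        PlSmallStr::from_static("a"),
        [Some(true), Some(false), None],
    );
    let b = BooleanChunked::new(
        PlSmallStr::from_static("b"),
        [Some(false), Some(false), Some(true)],
    );
    // The operators take references and broadcast length-1 operands.
    let or = &a | &b;
    let and = &a & &b;
    assert_eq!(or.len(), 3);
    // The third element pairs a null with a value, so the AND result is null there.
    assert_eq!(and.null_count(), 1);
}
```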
pub(crate) unsafe fn new( - name: &str, + name: PlSmallStr, width: usize, capacity: usize, logical_dtype: DataType, @@ -26,7 +26,7 @@ impl FixedSizeListNumericBuilder { Self { inner, width, - name: name.into(), + name, logical_dtype, } } @@ -77,7 +77,7 @@ impl FixedSizeListBuilder for FixedSizeListNumericBuilder { // SAFETY: physical type matches the logical unsafe { ChunkedArray::from_chunks_and_dtype( - self.name.as_str(), + self.name.clone(), vec![Box::new(arr)], DataType::Array(Box::new(self.logical_dtype.clone()), self.width), ) @@ -87,13 +87,13 @@ impl FixedSizeListBuilder for FixedSizeListNumericBuilder { pub(crate) struct AnonymousOwnedFixedSizeListBuilder { inner: fixed_size_list::AnonymousBuilder, - name: SmartString, + name: PlSmallStr, inner_dtype: Option, } impl AnonymousOwnedFixedSizeListBuilder { pub(crate) fn new( - name: &str, + name: PlSmallStr, width: usize, capacity: usize, inner_dtype: Option, @@ -101,7 +101,7 @@ impl AnonymousOwnedFixedSizeListBuilder { let inner = fixed_size_list::AnonymousBuilder::new(capacity, width); Self { inner, - name: name.into(), + name, inner_dtype, } } @@ -128,7 +128,7 @@ impl FixedSizeListBuilder for AnonymousOwnedFixedSizeListBuilder { .as_ref(), ) .unwrap(); - ChunkedArray::with_chunk(self.name.as_str(), arr) + ChunkedArray::with_chunk(self.name.clone(), arr) } } @@ -136,7 +136,7 @@ pub(crate) fn get_fixed_size_list_builder( inner_type_logical: &DataType, capacity: usize, width: usize, - name: &str, + name: PlSmallStr, ) -> PolarsResult> { let phys_dtype = inner_type_logical.to_physical(); diff --git a/crates/polars-core/src/chunked_array/builder/list/anonymous.rs b/crates/polars-core/src/chunked_array/builder/list/anonymous.rs index 99b566320fbf..80305ca043fb 100644 --- a/crates/polars-core/src/chunked_array/builder/list/anonymous.rs +++ b/crates/polars-core/src/chunked_array/builder/list/anonymous.rs @@ -1,7 +1,7 @@ use super::*; pub struct AnonymousListBuilder<'a> { - name: String, + name: PlSmallStr, builder: AnonymousBuilder<'a>, fast_explode: bool, inner_dtype: DtypeMerger, @@ -9,14 +9,14 @@ pub struct AnonymousListBuilder<'a> { impl Default for AnonymousListBuilder<'_> { fn default() -> Self { - Self::new("", 0, None) + Self::new(PlSmallStr::EMPTY, 0, None) } } impl<'a> AnonymousListBuilder<'a> { - pub fn new(name: &str, capacity: usize, inner_dtype: Option) -> Self { + pub fn new(name: PlSmallStr, capacity: usize, inner_dtype: Option) -> Self { Self { - name: name.into(), + name, builder: AnonymousBuilder::new(capacity), fast_explode: true, inner_dtype: DtypeMerger::new(inner_dtype), @@ -63,12 +63,6 @@ impl<'a> AnonymousListBuilder<'a> { // Empty arrays tend to be null type and thus differ // if we would push it the concat would fail. 
DataType::Null if s.is_empty() => self.append_empty(), - #[cfg(feature = "dtype-struct")] - DataType::Struct(_) => { - let arr = &**s.array_ref(0); - self.builder.push(arr); - return Ok(()); - }, dt => self.inner_dtype.update(dt)?, } self.builder.push_multiple(s.chunks()); @@ -80,7 +74,7 @@ impl<'a> AnonymousListBuilder<'a> { let slf = std::mem::take(self); if slf.builder.is_empty() { ListChunked::full_null_with_dtype( - &slf.name, + slf.name.clone(), 0, &slf.inner_dtype.materialize().unwrap_or(DataType::Null), ) @@ -93,22 +87,22 @@ impl<'a> AnonymousListBuilder<'a> { let arr = slf.builder.finish(inner_dtype_physical.as_ref()).unwrap(); let list_dtype_logical = match inner_dtype { - None => DataType::from(arr.data_type()), + None => DataType::from(arr.dtype()), Some(dt) => DataType::List(Box::new(dt)), }; - let mut ca = ListChunked::with_chunk("", arr); + let mut ca = ListChunked::with_chunk(PlSmallStr::EMPTY, arr); if slf.fast_explode { ca.set_fast_explode(); } - ca.field = Arc::new(Field::new(&slf.name, list_dtype_logical)); + ca.field = Arc::new(Field::new(slf.name.clone(), list_dtype_logical)); ca } } } pub struct AnonymousOwnedListBuilder { - name: String, + name: PlSmallStr, builder: AnonymousBuilder<'static>, owned: Vec, inner_dtype: DtypeMerger, @@ -117,7 +111,7 @@ pub struct AnonymousOwnedListBuilder { impl Default for AnonymousOwnedListBuilder { fn default() -> Self { - Self::new("", 0, None) + Self::new(PlSmallStr::EMPTY, 0, None) } } @@ -127,17 +121,9 @@ impl ListBuilderTrait for AnonymousOwnedListBuilder { self.append_empty(); } else { unsafe { - match s.dtype() { - #[cfg(feature = "dtype-struct")] - DataType::Struct(_) => { - self.builder.push(&*(&**s.array_ref(0) as *const dyn Array)); - }, - dt => { - self.inner_dtype.update(dt)?; - self.builder - .push_multiple(&*(s.chunks().as_ref() as *const [ArrayRef])); - }, - } + self.inner_dtype.update(s.dtype())?; + self.builder + .push_multiple(&*(s.chunks().as_ref() as *const [ArrayRef])); } // This make sure that the underlying ArrayRef's are not dropped. 
self.owned.push(s.clone()); @@ -161,23 +147,23 @@ impl ListBuilderTrait for AnonymousOwnedListBuilder { let arr = slf.builder.finish(inner_dtype_physical.as_ref()).unwrap(); let list_dtype_logical = match inner_dtype { - None => DataType::from_arrow(arr.data_type(), false), + None => DataType::from_arrow(arr.dtype(), false), Some(dt) => DataType::List(Box::new(dt)), }; - let mut ca = ListChunked::with_chunk("", arr); + let mut ca = ListChunked::with_chunk(PlSmallStr::EMPTY, arr); if slf.fast_explode { ca.set_fast_explode(); } - ca.field = Arc::new(Field::new(&slf.name, list_dtype_logical)); + ca.field = Arc::new(Field::new(slf.name.clone(), list_dtype_logical)); ca } } impl AnonymousOwnedListBuilder { - pub fn new(name: &str, capacity: usize, inner_dtype: Option) -> Self { + pub fn new(name: PlSmallStr, capacity: usize, inner_dtype: Option) -> Self { Self { - name: name.into(), + name, builder: AnonymousBuilder::new(capacity), owned: Vec::with_capacity(capacity), inner_dtype: DtypeMerger::new(inner_dtype), diff --git a/crates/polars-core/src/chunked_array/builder/list/binary.rs b/crates/polars-core/src/chunked_array/builder/list/binary.rs index 6382d9269f49..d55a69a2eace 100644 --- a/crates/polars-core/src/chunked_array/builder/list/binary.rs +++ b/crates/polars-core/src/chunked_array/builder/list/binary.rs @@ -7,7 +7,7 @@ pub struct ListStringChunkedBuilder { } impl ListStringChunkedBuilder { - pub fn new(name: &str, capacity: usize, values_capacity: usize) -> Self { + pub fn new(name: PlSmallStr, capacity: usize, values_capacity: usize) -> Self { let values = MutableBinaryViewArray::with_capacity(values_capacity); let builder = LargeListBinViewBuilder::new_with_capacity(values, capacity); let field = Field::new(name, DataType::List(Box::new(DataType::String))); @@ -97,7 +97,7 @@ pub struct ListBinaryChunkedBuilder { } impl ListBinaryChunkedBuilder { - pub fn new(name: &str, capacity: usize, values_capacity: usize) -> Self { + pub fn new(name: PlSmallStr, capacity: usize, values_capacity: usize) -> Self { let values = MutablePlBinary::with_capacity(values_capacity); let builder = LargeListBinViewBuilder::new_with_capacity(values, capacity); let field = Field::new(name, DataType::List(Box::new(DataType::Binary))); diff --git a/crates/polars-core/src/chunked_array/builder/list/boolean.rs b/crates/polars-core/src/chunked_array/builder/list/boolean.rs index 1d83a05ace00..8142d1a50954 100644 --- a/crates/polars-core/src/chunked_array/builder/list/boolean.rs +++ b/crates/polars-core/src/chunked_array/builder/list/boolean.rs @@ -7,7 +7,7 @@ pub struct ListBooleanChunkedBuilder { } impl ListBooleanChunkedBuilder { - pub fn new(name: &str, capacity: usize, values_capacity: usize) -> Self { + pub fn new(name: PlSmallStr, capacity: usize, values_capacity: usize) -> Self { let values = MutableBooleanArray::with_capacity(values_capacity); let builder = LargeListBooleanBuilder::new_with_capacity(values, capacity); let field = Field::new(name, DataType::List(Box::new(DataType::Boolean))); diff --git a/crates/polars-core/src/chunked_array/builder/list/categorical.rs b/crates/polars-core/src/chunked_array/builder/list/categorical.rs index 2807991b377d..3670e0ab3df9 100644 --- a/crates/polars-core/src/chunked_array/builder/list/categorical.rs +++ b/crates/polars-core/src/chunked_array/builder/list/categorical.rs @@ -1,9 +1,7 @@ -use ahash::RandomState; - use super::*; pub fn create_categorical_chunked_listbuilder( - name: &str, + name: PlSmallStr, ordering: CategoricalOrdering, capacity: usize, 
values_capacity: usize, @@ -35,7 +33,7 @@ pub struct ListEnumCategoricalChunkedBuilder { impl ListEnumCategoricalChunkedBuilder { pub(super) fn new( - name: &str, + name: PlSmallStr, ordering: CategoricalOrdering, capacity: usize, values_capacity: usize, @@ -88,12 +86,12 @@ struct KeyWrapper(u32); impl ListLocalCategoricalChunkedBuilder { #[inline] - pub fn get_hash_builder() -> RandomState { - RandomState::with_seed(0) + pub fn get_hash_builder() -> PlRandomState { + PlRandomState::with_seed(0) } pub(super) fn new( - name: &str, + name: PlSmallStr, ordering: CategoricalOrdering, capacity: usize, values_capacity: usize, @@ -208,7 +206,7 @@ struct ListGlobalCategoricalChunkedBuilder { impl ListGlobalCategoricalChunkedBuilder { pub(super) fn new( - name: &str, + name: PlSmallStr, ordering: CategoricalOrdering, capacity: usize, values_capacity: usize, diff --git a/crates/polars-core/src/chunked_array/builder/list/mod.rs b/crates/polars-core/src/chunked_array/builder/list/mod.rs index 9a7f9243dcf3..645a2a168e90 100644 --- a/crates/polars-core/src/chunked_array/builder/list/mod.rs +++ b/crates/polars-core/src/chunked_array/builder/list/mod.rs @@ -84,7 +84,7 @@ pub fn get_list_builder( inner_type_logical: &DataType, value_capacity: usize, list_capacity: usize, - name: &str, + name: PlSmallStr, ) -> PolarsResult> { match inner_type_logical { #[cfg(feature = "dtype-categorical")] @@ -159,21 +159,21 @@ pub fn get_list_builder( macro_rules! get_bool_builder { () => {{ let builder = - ListBooleanChunkedBuilder::new(&name, list_capacity, value_capacity); + ListBooleanChunkedBuilder::new(name, list_capacity, value_capacity); Box::new(builder) }}; } macro_rules! get_string_builder { () => {{ let builder = - ListStringChunkedBuilder::new(&name, list_capacity, 5 * value_capacity); + ListStringChunkedBuilder::new(name, list_capacity, 5 * value_capacity); Box::new(builder) }}; } macro_rules! 
get_binary_builder { () => {{ let builder = - ListBinaryChunkedBuilder::new(&name, list_capacity, 5 * value_capacity); + ListBinaryChunkedBuilder::new(name, list_capacity, 5 * value_capacity); Box::new(builder) }}; } diff --git a/crates/polars-core/src/chunked_array/builder/list/null.rs b/crates/polars-core/src/chunked_array/builder/list/null.rs index ab6e7a73ec7b..233f53e17412 100644 --- a/crates/polars-core/src/chunked_array/builder/list/null.rs +++ b/crates/polars-core/src/chunked_array/builder/list/null.rs @@ -2,14 +2,14 @@ use super::*; pub struct ListNullChunkedBuilder { builder: LargeListNullBuilder, - name: String, + name: PlSmallStr, } impl ListNullChunkedBuilder { - pub fn new(name: &str, capacity: usize) -> Self { + pub fn new(name: PlSmallStr, capacity: usize) -> Self { ListNullChunkedBuilder { builder: LargeListNullBuilder::with_capacity(capacity), - name: name.into(), + name, } } @@ -41,7 +41,7 @@ impl ListBuilderTrait for ListNullChunkedBuilder { fn finish(&mut self) -> ListChunked { unsafe { ListChunked::from_chunks_and_dtype_unchecked( - &self.name, + self.name.clone(), vec![self.builder.as_box()], DataType::List(Box::new(DataType::Null)), ) diff --git a/crates/polars-core/src/chunked_array/builder/list/primitive.rs b/crates/polars-core/src/chunked_array/builder/list/primitive.rs index d9555716d45d..0b1de987efb4 100644 --- a/crates/polars-core/src/chunked_array/builder/list/primitive.rs +++ b/crates/polars-core/src/chunked_array/builder/list/primitive.rs @@ -14,7 +14,7 @@ where T: PolarsNumericType, { pub fn new( - name: &str, + name: PlSmallStr, capacity: usize, values_capacity: usize, logical_type: DataType, @@ -31,7 +31,7 @@ where } pub fn new_with_values_type( - name: &str, + name: PlSmallStr, capacity: usize, values_capacity: usize, values_type: DataType, diff --git a/crates/polars-core/src/chunked_array/builder/mod.rs b/crates/polars-core/src/chunked_array/builder/mod.rs index bac88f5a1ea5..539586c2193e 100644 --- a/crates/polars-core/src/chunked_array/builder/mod.rs +++ b/crates/polars-core/src/chunked_array/builder/mod.rs @@ -46,36 +46,36 @@ where let chunks = iter .into_iter() .map(|(values, opt_buffer)| to_primitive::(values, opt_buffer)); - ChunkedArray::from_chunk_iter("from_iter", chunks) + ChunkedArray::from_chunk_iter(PlSmallStr::EMPTY, chunks) } } pub trait NewChunkedArray { - fn from_slice(name: &str, v: &[N]) -> Self; - fn from_slice_options(name: &str, opt_v: &[Option]) -> Self; + fn from_slice(name: PlSmallStr, v: &[N]) -> Self; + fn from_slice_options(name: PlSmallStr, opt_v: &[Option]) -> Self; /// Create a new ChunkedArray from an iterator. - fn from_iter_options(name: &str, it: impl Iterator>) -> Self; + fn from_iter_options(name: PlSmallStr, it: impl Iterator>) -> Self; /// Create a new ChunkedArray from an iterator. 
- fn from_iter_values(name: &str, it: impl Iterator) -> Self; + fn from_iter_values(name: PlSmallStr, it: impl Iterator) -> Self; } impl NewChunkedArray for ChunkedArray where T: PolarsNumericType, { - fn from_slice(name: &str, v: &[T::Native]) -> Self { + fn from_slice(name: PlSmallStr, v: &[T::Native]) -> Self { let arr = PrimitiveArray::from_slice(v).to(T::get_dtype().to_arrow(CompatLevel::newest())); ChunkedArray::with_chunk(name, arr) } - fn from_slice_options(name: &str, opt_v: &[Option]) -> Self { + fn from_slice_options(name: PlSmallStr, opt_v: &[Option]) -> Self { Self::from_iter_options(name, opt_v.iter().copied()) } fn from_iter_options( - name: &str, + name: PlSmallStr, it: impl Iterator>, ) -> ChunkedArray { let mut builder = PrimitiveChunkedBuilder::new(name, get_iter_capacity(&it)); @@ -84,7 +84,7 @@ where } /// Create a new ChunkedArray from an iterator. - fn from_iter_values(name: &str, it: impl Iterator) -> ChunkedArray { + fn from_iter_values(name: PlSmallStr, it: impl Iterator) -> ChunkedArray { let ca: NoNull> = it.collect(); let mut ca = ca.into_inner(); ca.rename(name); @@ -93,16 +93,16 @@ where } impl NewChunkedArray for BooleanChunked { - fn from_slice(name: &str, v: &[bool]) -> Self { + fn from_slice(name: PlSmallStr, v: &[bool]) -> Self { Self::from_iter_values(name, v.iter().copied()) } - fn from_slice_options(name: &str, opt_v: &[Option]) -> Self { + fn from_slice_options(name: PlSmallStr, opt_v: &[Option]) -> Self { Self::from_iter_options(name, opt_v.iter().copied()) } fn from_iter_options( - name: &str, + name: PlSmallStr, it: impl Iterator>, ) -> ChunkedArray { let mut builder = BooleanChunkedBuilder::new(name, get_iter_capacity(&it)); @@ -111,7 +111,10 @@ impl NewChunkedArray for BooleanChunked { } /// Create a new ChunkedArray from an iterator. - fn from_iter_values(name: &str, it: impl Iterator) -> ChunkedArray { + fn from_iter_values( + name: PlSmallStr, + it: impl Iterator, + ) -> ChunkedArray { let mut ca: ChunkedArray<_> = it.collect(); ca.rename(name); ca @@ -122,23 +125,23 @@ impl NewChunkedArray for StringChunked where S: AsRef, { - fn from_slice(name: &str, v: &[S]) -> Self { + fn from_slice(name: PlSmallStr, v: &[S]) -> Self { let arr = Utf8ViewArray::from_slice_values(v); ChunkedArray::with_chunk(name, arr) } - fn from_slice_options(name: &str, opt_v: &[Option]) -> Self { + fn from_slice_options(name: PlSmallStr, opt_v: &[Option]) -> Self { let arr = Utf8ViewArray::from_slice(opt_v); ChunkedArray::with_chunk(name, arr) } - fn from_iter_options(name: &str, it: impl Iterator>) -> Self { + fn from_iter_options(name: PlSmallStr, it: impl Iterator>) -> Self { let arr = MutableBinaryViewArray::from_iterator(it).freeze(); ChunkedArray::with_chunk(name, arr) } /// Create a new ChunkedArray from an iterator. 
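The `NewChunkedArray` constructors above now take an owned `PlSmallStr` instead of `&str`. A sketch of a migrated call site, assuming the usual `polars_core::prelude` re-exports and illustrative column names; `PlSmallStr::from_static` is the constructor this patch uses for `'static` literals:

```rust
use polars_core::prelude::*;
use polars_utils::pl_str::PlSmallStr;

fn build_columns() -> (Int32Chunked, BooleanChunked) {
    // Owned, cheaply clonable names replace the old `&str` parameters.
    let ints = Int32Chunked::from_slice(PlSmallStr::from_static("a"), &[1, 2, 3]);
    let bools = BooleanChunked::from_slice_options(
        PlSmallStr::from_static("b"),
        &[Some(true), None, Some(false)],
    );
    (ints, bools)
}

fn main() {
    let (ints, bools) = build_columns();
    assert_eq!(ints.len(), 3);
    assert_eq!(bools.null_count(), 1);
}
```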
- fn from_iter_values(name: &str, it: impl Iterator) -> Self { + fn from_iter_values(name: PlSmallStr, it: impl Iterator) -> Self { let arr = MutableBinaryViewArray::from_values_iter(it).freeze(); ChunkedArray::with_chunk(name, arr) } @@ -148,23 +151,23 @@ impl NewChunkedArray for BinaryChunked where B: AsRef<[u8]>, { - fn from_slice(name: &str, v: &[B]) -> Self { + fn from_slice(name: PlSmallStr, v: &[B]) -> Self { let arr = BinaryViewArray::from_slice_values(v); ChunkedArray::with_chunk(name, arr) } - fn from_slice_options(name: &str, opt_v: &[Option]) -> Self { + fn from_slice_options(name: PlSmallStr, opt_v: &[Option]) -> Self { let arr = BinaryViewArray::from_slice(opt_v); ChunkedArray::with_chunk(name, arr) } - fn from_iter_options(name: &str, it: impl Iterator>) -> Self { + fn from_iter_options(name: PlSmallStr, it: impl Iterator>) -> Self { let arr = MutableBinaryViewArray::from_iterator(it).freeze(); ChunkedArray::with_chunk(name, arr) } /// Create a new ChunkedArray from an iterator. - fn from_iter_values(name: &str, it: impl Iterator) -> Self { + fn from_iter_values(name: PlSmallStr, it: impl Iterator) -> Self { let arr = MutableBinaryViewArray::from_values_iter(it).freeze(); ChunkedArray::with_chunk(name, arr) } @@ -176,7 +179,8 @@ mod test { #[test] fn test_primitive_builder() { - let mut builder = PrimitiveChunkedBuilder::::new("foo", 6); + let mut builder = + PrimitiveChunkedBuilder::::new(PlSmallStr::from_static("foo"), 6); let values = &[Some(1), None, Some(2), Some(3), None, Some(4)]; for val in values { builder.append_option(*val); @@ -187,12 +191,17 @@ mod test { #[test] fn test_list_builder() { - let mut builder = - ListPrimitiveChunkedBuilder::::new("a", 10, 5, DataType::Int32); + let mut builder = ListPrimitiveChunkedBuilder::::new( + PlSmallStr::from_static("a"), + 10, + 5, + DataType::Int32, + ); // Create a series containing two chunks. 
- let mut s1 = Int32Chunked::from_slice("a", &[1, 2, 3]).into_series(); - let s2 = Int32Chunked::from_slice("b", &[4, 5, 6]).into_series(); + let mut s1 = + Int32Chunked::from_slice(PlSmallStr::from_static("a"), &[1, 2, 3]).into_series(); + let s2 = Int32Chunked::from_slice(PlSmallStr::from_static("b"), &[4, 5, 6]).into_series(); s1.append(&s2).unwrap(); builder.append_series(&s1).unwrap(); @@ -215,8 +224,12 @@ mod test { assert_eq!(out.get_as_series(0).unwrap().len(), 6); assert_eq!(out.get_as_series(1).unwrap().len(), 3); - let mut builder = - ListPrimitiveChunkedBuilder::::new("a", 10, 5, DataType::Int32); + let mut builder = ListPrimitiveChunkedBuilder::::new( + PlSmallStr::from_static("a"), + 10, + 5, + DataType::Int32, + ); builder.append_series(&s1).unwrap(); builder.append_null(); diff --git a/crates/polars-core/src/chunked_array/builder/null.rs b/crates/polars-core/src/chunked_array/builder/null.rs index 8e4d5b9cb107..f4101a2a14e7 100644 --- a/crates/polars-core/src/chunked_array/builder/null.rs +++ b/crates/polars-core/src/chunked_array/builder/null.rs @@ -10,7 +10,7 @@ pub struct NullChunkedBuilder { } impl NullChunkedBuilder { - pub fn new(name: &str, len: usize) -> Self { + pub fn new(name: PlSmallStr, len: usize) -> Self { let array_builder = MutableNullArray::new(len); NullChunkedBuilder { @@ -27,7 +27,7 @@ impl NullChunkedBuilder { pub fn finish(mut self) -> NullChunked { let arr = self.array_builder.as_box(); - let ca = NullChunked::new(Arc::from(self.field.name.as_str()), arr.len()); + let ca = NullChunked::new(self.field.name().clone(), arr.len()); ca } diff --git a/crates/polars-core/src/chunked_array/builder/primitive.rs b/crates/polars-core/src/chunked_array/builder/primitive.rs index 14eb2c1f4f46..f310d4145a19 100644 --- a/crates/polars-core/src/chunked_array/builder/primitive.rs +++ b/crates/polars-core/src/chunked_array/builder/primitive.rs @@ -39,7 +39,7 @@ impl PrimitiveChunkedBuilder where T: PolarsNumericType, { - pub fn new(name: &str, capacity: usize) -> Self { + pub fn new(name: PlSmallStr, capacity: usize) -> Self { let array_builder = MutablePrimitiveArray::::with_capacity(capacity) .to(T::get_dtype().to_arrow(CompatLevel::newest())); diff --git a/crates/polars-core/src/chunked_array/builder/string.rs b/crates/polars-core/src/chunked_array/builder/string.rs index 36c1d90492bc..8375760c606d 100644 --- a/crates/polars-core/src/chunked_array/builder/string.rs +++ b/crates/polars-core/src/chunked_array/builder/string.rs @@ -18,13 +18,12 @@ pub type StringChunkedBuilder = BinViewChunkedBuilder; pub type BinaryChunkedBuilder = BinViewChunkedBuilder<[u8]>; impl BinViewChunkedBuilder { - /// Create a new StringChunkedBuilder + /// Create a new BinViewChunkedBuilder /// /// # Arguments /// /// * `capacity` - Number of string elements in the final array. - /// * `bytes_capacity` - Number of bytes needed to store the string values. 
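The low-level builders follow the same convention: `BooleanChunkedBuilder::new` and `PrimitiveChunkedBuilder::new` above now take a `PlSmallStr` name. Mirroring the updated `test_primitive_builder`, a hedged sketch of the call pattern (assuming `append_option` and `finish` are in scope via the prelude, as in the test):

```rust
use polars_core::prelude::*;
use polars_utils::pl_str::PlSmallStr;

fn main() {
    // The builder name is now an owned PlSmallStr; capacity is still an element count.
    let mut builder =
        PrimitiveChunkedBuilder::<UInt32Type>::new(PlSmallStr::from_static("foo"), 4);
    for v in [Some(1u32), None, Some(2), Some(3)] {
        builder.append_option(v);
    }
    let ca = builder.finish();
    assert_eq!(ca.len(), 4);
    assert_eq!(ca.null_count(), 1);
}
```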
- pub fn new(name: &str, capacity: usize) -> Self { + pub fn new(name: PlSmallStr, capacity: usize) -> Self { Self { chunk_builder: MutableBinaryViewArray::with_capacity(capacity), field: Arc::new(Field::new(name, DataType::from(&T::DATA_TYPE))), diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index acbfb0839807..53f6e85f221d 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -70,7 +70,7 @@ pub(crate) fn cast_chunks( } fn cast_impl_inner( - name: &str, + name: PlSmallStr, chunks: &[ArrayRef], dtype: &DataType, options: CastOptions, @@ -98,7 +98,7 @@ fn cast_impl_inner( } fn cast_impl( - name: &str, + name: PlSmallStr, chunks: &[ArrayRef], dtype: &DataType, options: CastOptions, @@ -108,7 +108,7 @@ fn cast_impl( #[cfg(feature = "dtype-struct")] fn cast_single_to_struct( - name: &str, + name: PlSmallStr, chunks: &[ArrayRef], fields: &[Field], options: CastOptions, @@ -117,12 +117,12 @@ fn cast_single_to_struct( // cast to first field dtype let mut fields = fields.iter(); let fld = fields.next().unwrap(); - let s = cast_impl_inner(&fld.name, chunks, &fld.dtype, options)?; + let s = cast_impl_inner(fld.name.clone(), chunks, &fld.dtype, options)?; let length = s.len(); new_fields.push(s); for fld in fields { - new_fields.push(Series::full_null(&fld.name, length, &fld.dtype)); + new_fields.push(Series::full_null(fld.name.clone(), length, &fld.dtype)); } StructChunked::from_series(name, &new_fields).map(|ca| ca.into_series()) @@ -132,16 +132,20 @@ impl ChunkedArray where T: PolarsNumericType, { - fn cast_impl(&self, data_type: &DataType, options: CastOptions) -> PolarsResult { - if self.dtype() == data_type { + fn cast_impl(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + if self.dtype() == dtype { // SAFETY: chunks are correct dtype let mut out = unsafe { - Series::from_chunks_and_dtype_unchecked(self.name(), self.chunks.clone(), data_type) + Series::from_chunks_and_dtype_unchecked( + self.name().clone(), + self.chunks.clone(), + dtype, + ) }; out.set_sorted_flag(self.is_sorted_flag()); return Ok(out); } - match data_type { + match dtype { #[cfg(feature = "dtype-categorical")] DataType::Categorical(_, ordering) => { polars_ensure!( @@ -195,24 +199,23 @@ where }, #[cfg(feature = "dtype-struct")] DataType::Struct(fields) => { - cast_single_to_struct(self.name(), &self.chunks, fields, options) + cast_single_to_struct(self.name().clone(), &self.chunks, fields, options) }, - _ => cast_impl_inner(self.name(), &self.chunks, data_type, options).map(|mut s| { + _ => cast_impl_inner(self.name().clone(), &self.chunks, dtype, options).map(|mut s| { // maintain sorted if data types // - remain signed // - unsigned -> signed // this may still fail with overflow? 
let dtype = self.dtype(); - let to_signed = data_type.is_signed_integer(); - let unsigned2unsigned = - dtype.is_unsigned_integer() && data_type.is_unsigned_integer(); + let to_signed = dtype.is_signed_integer(); + let unsigned2unsigned = dtype.is_unsigned_integer() && dtype.is_unsigned_integer(); let allowed = to_signed || unsigned2unsigned; if (allowed) && (s.null_count() == self.null_count()) // physical to logicals - || (self.dtype().to_physical() == data_type.to_physical()) + || (self.dtype().to_physical() == dtype.to_physical()) { let is_sorted = self.is_sorted_flag(); s.set_sorted_flag(is_sorted) @@ -227,16 +230,12 @@ impl ChunkCast for ChunkedArray where T: PolarsNumericType, { - fn cast_with_options( - &self, - data_type: &DataType, - options: CastOptions, - ) -> PolarsResult { - self.cast_impl(data_type, options) + fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + self.cast_impl(dtype, options) } - unsafe fn cast_unchecked(&self, data_type: &DataType) -> PolarsResult { - match data_type { + unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { + match dtype { #[cfg(feature = "dtype-categorical")] DataType::Categorical(Some(rev_map), ordering) | DataType::Enum(Some(rev_map), ordering) => { @@ -248,7 +247,7 @@ where CategoricalChunked::from_cats_and_rev_map_unchecked( ca.clone(), rev_map.clone(), - matches!(data_type, DataType::Enum(_, _)), + matches!(dtype, DataType::Enum(_, _)), *ordering, ) } @@ -257,18 +256,14 @@ where polars_bail!(ComputeError: "cannot cast numeric types to 'Categorical'"); } }, - _ => self.cast_impl(data_type, CastOptions::Overflowing), + _ => self.cast_impl(dtype, CastOptions::Overflowing), } } } impl ChunkCast for StringChunked { - fn cast_with_options( - &self, - data_type: &DataType, - options: CastOptions, - ) -> PolarsResult { - match data_type { + fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + match dtype { #[cfg(feature = "dtype-categorical")] DataType::Categorical(rev_map, ordering) => match rev_map { None => { @@ -276,7 +271,7 @@ impl ChunkCast for StringChunked { let iter = unsafe { self.downcast_iter().flatten().trust_my_length(self.len()) }; let builder = - CategoricalChunkedBuilder::new(self.name(), self.len(), *ordering); + CategoricalChunkedBuilder::new(self.name().clone(), self.len(), *ordering); let ca = builder.drain_iter_and_finish(iter); Ok(ca.into_series()) }, @@ -292,13 +287,13 @@ impl ChunkCast for StringChunked { CategoricalChunked::from_string_to_enum(self, rev_map.get_categories(), *ordering) .map(|ca| { let mut s = ca.into_series(); - s.rename(self.name()); + s.rename(self.name().clone()); s }) }, #[cfg(feature = "dtype-struct")] DataType::Struct(fields) => { - cast_single_to_struct(self.name(), &self.chunks, fields, options) + cast_single_to_struct(self.name().clone(), &self.chunks, fields, options) }, #[cfg(feature = "dtype-decimal")] DataType::Decimal(precision, scale) => match (precision, scale) { @@ -310,7 +305,7 @@ impl ChunkCast for StringChunked { *scale, ) }); - Ok(Int128Chunked::from_chunk_iter(self.name(), chunks) + Ok(Int128Chunked::from_chunk_iter(self.name().clone(), chunks) .into_decimal_unchecked(*precision, *scale) .into_series()) }, @@ -321,8 +316,8 @@ impl ChunkCast for StringChunked { }, #[cfg(feature = "dtype-date")] DataType::Date => { - let result = cast_chunks(&self.chunks, data_type, options)?; - let out = Series::try_from((self.name(), result))?; + let result = cast_chunks(&self.chunks, dtype, options)?; + let out = 
Series::try_from((self.name().clone(), result))?; Ok(out) }, #[cfg(feature = "dtype-datetime")] @@ -336,7 +331,7 @@ impl ChunkCast for StringChunked { &Datetime(time_unit.to_owned(), Some(time_zone.clone())), options, )?; - Series::try_from((self.name(), result)) + Series::try_from((self.name().clone(), result)) }, _ => { let result = cast_chunks( @@ -344,17 +339,17 @@ impl ChunkCast for StringChunked { &Datetime(time_unit.to_owned(), None), options, )?; - Series::try_from((self.name(), result)) + Series::try_from((self.name().clone(), result)) }, }; out }, - _ => cast_impl(self.name(), &self.chunks, data_type, options), + _ => cast_impl(self.name().clone(), &self.chunks, dtype, options), } } - unsafe fn cast_unchecked(&self, data_type: &DataType) -> PolarsResult { - self.cast_with_options(data_type, CastOptions::Overflowing) + unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { + self.cast_with_options(dtype, CastOptions::Overflowing) } } @@ -366,7 +361,7 @@ impl BinaryChunked { .downcast_iter() .map(|arr| arr.to_utf8view_unchecked().boxed()) .collect(); - let field = Arc::new(Field::new(self.name(), DataType::String)); + let field = Arc::new(Field::new(self.name().clone(), DataType::String)); let mut ca = StringChunked::new_with_compute_len(field, chunks); @@ -383,7 +378,7 @@ impl StringChunked { .downcast_iter() .map(|arr| arr.to_binview().boxed()) .collect(); - let field = Arc::new(Field::new(self.name(), DataType::Binary)); + let field = Arc::new(Field::new(self.name().clone(), DataType::Binary)); let mut ca = BinaryChunked::new_with_compute_len(field, chunks); @@ -395,78 +390,62 @@ impl StringChunked { } impl ChunkCast for BinaryChunked { - fn cast_with_options( - &self, - data_type: &DataType, - options: CastOptions, - ) -> PolarsResult { - match data_type { + fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + match dtype { #[cfg(feature = "dtype-struct")] DataType::Struct(fields) => { - cast_single_to_struct(self.name(), &self.chunks, fields, options) + cast_single_to_struct(self.name().clone(), &self.chunks, fields, options) }, - _ => cast_impl(self.name(), &self.chunks, data_type, options), + _ => cast_impl(self.name().clone(), &self.chunks, dtype, options), } } - unsafe fn cast_unchecked(&self, data_type: &DataType) -> PolarsResult { - match data_type { + unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { + match dtype { DataType::String => unsafe { Ok(self.to_string_unchecked().into_series()) }, - _ => self.cast_with_options(data_type, CastOptions::Overflowing), + _ => self.cast_with_options(dtype, CastOptions::Overflowing), } } } impl ChunkCast for BinaryOffsetChunked { - fn cast_with_options( - &self, - data_type: &DataType, - options: CastOptions, - ) -> PolarsResult { - match data_type { + fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + match dtype { #[cfg(feature = "dtype-struct")] DataType::Struct(fields) => { - cast_single_to_struct(self.name(), &self.chunks, fields, options) + cast_single_to_struct(self.name().clone(), &self.chunks, fields, options) }, - _ => cast_impl(self.name(), &self.chunks, data_type, options), + _ => cast_impl(self.name().clone(), &self.chunks, dtype, options), } } - unsafe fn cast_unchecked(&self, data_type: &DataType) -> PolarsResult { - self.cast_with_options(data_type, CastOptions::Overflowing) + unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { + self.cast_with_options(dtype, CastOptions::Overflowing) } } impl 
ChunkCast for BooleanChunked { - fn cast_with_options( - &self, - data_type: &DataType, - options: CastOptions, - ) -> PolarsResult { - match data_type { + fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + match dtype { #[cfg(feature = "dtype-struct")] DataType::Struct(fields) => { - cast_single_to_struct(self.name(), &self.chunks, fields, options) + cast_single_to_struct(self.name().clone(), &self.chunks, fields, options) }, - _ => cast_impl(self.name(), &self.chunks, data_type, options), + _ => cast_impl(self.name().clone(), &self.chunks, dtype, options), } } - unsafe fn cast_unchecked(&self, data_type: &DataType) -> PolarsResult { - self.cast_with_options(data_type, CastOptions::Overflowing) + unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { + self.cast_with_options(dtype, CastOptions::Overflowing) } } /// We cannot cast anything to or from List/LargeList /// So this implementation casts the inner type impl ChunkCast for ListChunked { - fn cast_with_options( - &self, - data_type: &DataType, - options: CastOptions, - ) -> PolarsResult { + fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { use DataType::*; - match data_type { + match dtype { List(child_type) => { match (self.inner_dtype(), &**child_type) { (old, new) if old == new => Ok(self.clone().into_series()), @@ -483,7 +462,7 @@ impl ChunkCast for ListChunked { // we must take this path to correct for physical types. unsafe { Ok(Series::from_chunks_and_dtype_unchecked( - self.name(), + self.name().clone(), vec![arr], &List(Box::new(child_type)), )) @@ -493,7 +472,7 @@ impl ChunkCast for ListChunked { }, #[cfg(feature = "dtype-array")] Array(child_type, width) => { - let physical_type = data_type.to_physical(); + let physical_type = dtype.to_physical(); // TODO!: properly implement this recursively. #[cfg(feature = "dtype-categorical")] @@ -505,7 +484,7 @@ impl ChunkCast for ListChunked { // we must take this path to correct for physical types. unsafe { Ok(Series::from_chunks_and_dtype_unchecked( - self.name(), + self.name().clone(), chunks, &Array(child_type.clone(), *width), )) @@ -515,17 +494,17 @@ impl ChunkCast for ListChunked { polars_bail!( InvalidOperation: "cannot cast List type (inner: '{:?}', to: '{:?}')", self.inner_dtype(), - data_type, + dtype, ) }, } } - unsafe fn cast_unchecked(&self, data_type: &DataType) -> PolarsResult { + unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { use DataType::*; - match data_type { + match dtype { List(child_type) => cast_list_unchecked(self, child_type), - _ => self.cast_with_options(data_type, CastOptions::Overflowing), + _ => self.cast_with_options(dtype, CastOptions::Overflowing), } } } @@ -534,13 +513,9 @@ impl ChunkCast for ListChunked { /// So this implementation casts the inner type #[cfg(feature = "dtype-array")] impl ChunkCast for ArrayChunked { - fn cast_with_options( - &self, - data_type: &DataType, - options: CastOptions, - ) -> PolarsResult { + fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { use DataType::*; - match data_type { + match dtype { Array(child_type, width) => { polars_ensure!( *width == self.width(), @@ -560,7 +535,7 @@ impl ChunkCast for ArrayChunked { // we must take this path to correct for physical types. 
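All of the `ChunkCast` impls in this file keep their public shape: only the `data_type` parameter is renamed to `dtype`, and names are now cloned `PlSmallStr`s. A hedged sketch of the caller side, assuming `ChunkCast`, `CastOptions` and the list builder are re-exported through the prelude; `CastOptions::Overflowing` is the variant the `cast_unchecked` paths above fall back to:

```rust
use polars_core::prelude::*;
use polars_utils::pl_str::PlSmallStr;

fn main() -> PolarsResult<()> {
    // Plain cast of a numeric column: the result is returned as a Series.
    let ints = Int32Chunked::from_slice(PlSmallStr::from_static("a"), &[1, 2, 3]);
    let floats = ints.cast_with_options(&DataType::Float64, CastOptions::Overflowing)?;
    assert_eq!(floats.dtype(), &DataType::Float64);

    // Casting a List column only recasts the inner type; the outer List stays.
    let mut builder = ListPrimitiveChunkedBuilder::<Int32Type>::new(
        PlSmallStr::from_static("l"),
        2,
        6,
        DataType::Int32,
    );
    builder.append_opt_slice(Some(&[1i32, 2, 3]));
    builder.append_opt_slice(Some(&[4i32, 5, 6]));
    let list = builder.finish();
    let casted = list.cast_with_options(
        &DataType::List(Box::new(DataType::Float64)),
        CastOptions::Overflowing,
    )?;
    assert_eq!(
        casted.dtype(),
        &DataType::List(Box::new(DataType::Float64))
    );
    Ok(())
}
```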
unsafe { Ok(Series::from_chunks_and_dtype_unchecked( - self.name(), + self.name().clone(), vec![arr], &Array(Box::new(child_type), *width), )) @@ -569,14 +544,14 @@ impl ChunkCast for ArrayChunked { } }, List(child_type) => { - let physical_type = data_type.to_physical(); + let physical_type = dtype.to_physical(); // cast to the physical type to avoid logical chunks. let chunks = cast_chunks(self.chunks(), &physical_type, options)?; // SAFETY: we just casted so the dtype matches. // we must take this path to correct for physical types. unsafe { Ok(Series::from_chunks_and_dtype_unchecked( - self.name(), + self.name().clone(), chunks, &List(child_type.clone()), )) @@ -586,14 +561,14 @@ impl ChunkCast for ArrayChunked { polars_bail!( InvalidOperation: "cannot cast Array type (inner: '{:?}', to: '{:?}')", self.inner_dtype(), - data_type, + dtype, ) }, } } - unsafe fn cast_unchecked(&self, data_type: &DataType) -> PolarsResult { - self.cast_with_options(data_type, CastOptions::Overflowing) + unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult { + self.cast_with_options(dtype, CastOptions::Overflowing) } } @@ -610,7 +585,11 @@ fn cast_list( let arr = ca.downcast_iter().next().unwrap(); // SAFETY: inner dtype is passed correctly let s = unsafe { - Series::from_chunks_and_dtype_unchecked("", vec![arr.values().clone()], ca.inner_dtype()) + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr.values().clone()], + ca.inner_dtype(), + ) }; let new_inner = s.cast_with_options(child_type, options)?; @@ -619,9 +598,9 @@ fn cast_list( let new_values = new_inner.array_ref(0).clone(); - let data_type = ListArray::::default_datatype(new_values.data_type().clone()); + let dtype = ListArray::::default_datatype(new_values.dtype().clone()); let new_arr = ListArray::::new( - data_type, + dtype, arr.offsets().clone(), new_values, arr.validity().cloned(), @@ -635,20 +614,24 @@ unsafe fn cast_list_unchecked(ca: &ListChunked, child_type: &DataType) -> Polars let arr = ca.downcast_iter().next().unwrap(); // SAFETY: inner dtype is passed correctly let s = unsafe { - Series::from_chunks_and_dtype_unchecked("", vec![arr.values().clone()], ca.inner_dtype()) + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr.values().clone()], + ca.inner_dtype(), + ) }; let new_inner = s.cast_unchecked(child_type)?; let new_values = new_inner.array_ref(0).clone(); - let data_type = ListArray::::default_datatype(new_values.data_type().clone()); + let dtype = ListArray::::default_datatype(new_values.dtype().clone()); let new_arr = ListArray::::new( - data_type, + dtype, arr.offsets().clone(), new_values, arr.validity().cloned(), ); Ok(ListChunked::from_chunks_and_dtype_unchecked( - ca.name(), + ca.name().clone(), vec![Box::new(new_arr)], DataType::List(Box::new(child_type.clone())), ) @@ -667,7 +650,11 @@ fn cast_fixed_size_list( let arr = ca.downcast_iter().next().unwrap(); // SAFETY: inner dtype is passed correctly let s = unsafe { - Series::from_chunks_and_dtype_unchecked("", vec![arr.values().clone()], ca.inner_dtype()) + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr.values().clone()], + ca.inner_dtype(), + ) }; let new_inner = s.cast_with_options(child_type, options)?; @@ -676,9 +663,8 @@ fn cast_fixed_size_list( let new_values = new_inner.array_ref(0).clone(); - let data_type = - FixedSizeListArray::default_datatype(new_values.data_type().clone(), ca.width()); - let new_arr = FixedSizeListArray::new(data_type, new_values, arr.validity().cloned()); + 
let dtype = FixedSizeListArray::default_datatype(new_values.dtype().clone(), ca.width()); + let new_arr = FixedSizeListArray::new(dtype, new_values, arr.validity().cloned()); Ok((Box::new(new_arr), inner_dtype)) } @@ -689,8 +675,12 @@ mod test { #[test] fn test_cast_list() -> PolarsResult<()> { - let mut builder = - ListPrimitiveChunkedBuilder::::new("a", 10, 10, DataType::Int32); + let mut builder = ListPrimitiveChunkedBuilder::::new( + PlSmallStr::from_static("a"), + 10, + 10, + DataType::Int32, + ); builder.append_opt_slice(Some(&[1i32, 2, 3])); builder.append_opt_slice(Some(&[1i32, 2, 3])); let ca = builder.finish(); @@ -708,7 +698,7 @@ mod test { #[cfg(feature = "dtype-categorical")] fn test_cast_noop() { // check if we can cast categorical twice without panic - let ca = StringChunked::new("foo", &["bar", "ham"]); + let ca = StringChunked::new(PlSmallStr::from_static("foo"), &["bar", "ham"]); let out = ca .cast_with_options( &DataType::Categorical(None, Default::default()), diff --git a/crates/polars-core/src/chunked_array/collect.rs b/crates/polars-core/src/chunked_array/collect.rs index 054f59de8958..eb882b1acc13 100644 --- a/crates/polars-core/src/chunked_array/collect.rs +++ b/crates/polars-core/src/chunked_array/collect.rs @@ -13,6 +13,7 @@ use std::sync::Arc; use arrow::trusted_len::TrustedLen; +use polars_utils::pl_str::PlSmallStr; use crate::chunked_array::ChunkedArray; use crate::datatypes::{ @@ -22,7 +23,7 @@ use crate::prelude::CompatLevel; pub trait ChunkedCollectIterExt: Iterator + Sized { #[inline] - fn collect_ca_with_dtype(self, name: &str, dtype: DataType) -> ChunkedArray + fn collect_ca_with_dtype(self, name: PlSmallStr, dtype: DataType) -> ChunkedArray where T::Array: ArrayFromIterDtype, { @@ -42,7 +43,7 @@ pub trait ChunkedCollectIterExt: Iterator + Sized { } #[inline] - fn collect_ca_trusted_with_dtype(self, name: &str, dtype: DataType) -> ChunkedArray + fn collect_ca_trusted_with_dtype(self, name: PlSmallStr, dtype: DataType) -> ChunkedArray where T::Array: ArrayFromIterDtype, Self: TrustedLen, @@ -66,7 +67,7 @@ pub trait ChunkedCollectIterExt: Iterator + Sized { #[inline] fn try_collect_ca_with_dtype( self, - name: &str, + name: PlSmallStr, dtype: DataType, ) -> Result, E> where @@ -95,7 +96,7 @@ pub trait ChunkedCollectIterExt: Iterator + Sized { #[inline] fn try_collect_ca_trusted_with_dtype( self, - name: &str, + name: PlSmallStr, dtype: DataType, ) -> Result, E> where @@ -128,7 +129,7 @@ impl ChunkedCollectIterExt for I {} pub trait ChunkedCollectInferIterExt: Iterator + Sized { #[inline] - fn collect_ca(self, name: &str) -> ChunkedArray + fn collect_ca(self, name: PlSmallStr) -> ChunkedArray where T::Array: ArrayFromIter, { @@ -138,7 +139,7 @@ pub trait ChunkedCollectInferIterExt: Iterator + Sized { } #[inline] - fn collect_ca_trusted(self, name: &str) -> ChunkedArray + fn collect_ca_trusted(self, name: PlSmallStr) -> ChunkedArray where T::Array: ArrayFromIter, Self: TrustedLen, @@ -149,7 +150,7 @@ pub trait ChunkedCollectInferIterExt: Iterator + Sized { } #[inline] - fn try_collect_ca(self, name: &str) -> Result, E> + fn try_collect_ca(self, name: PlSmallStr) -> Result, E> where T::Array: ArrayFromIter, Self: Iterator>, @@ -160,7 +161,7 @@ pub trait ChunkedCollectInferIterExt: Iterator + Sized { } #[inline] - fn try_collect_ca_trusted(self, name: &str) -> Result, E> + fn try_collect_ca_trusted(self, name: PlSmallStr) -> Result, E> where T::Array: ArrayFromIter, Self: Iterator> + TrustedLen, diff --git 
a/crates/polars-core/src/chunked_array/comparison/categorical.rs b/crates/polars-core/src/chunked_array/comparison/categorical.rs index 77ddf45a5a69..faa7f619cdb2 100644 --- a/crates/polars-core/src/chunked_array/comparison/categorical.rs +++ b/crates/polars-core/src/chunked_array/comparison/categorical.rs @@ -57,13 +57,13 @@ where .map(|phys| rev_map_r.get_unchecked(phys)) }; let Some(v) = v else { - return Ok(BooleanChunked::full_null(lhs.name(), lhs_len)); + return Ok(BooleanChunked::full_null(lhs.name().clone(), lhs_len)); }; Ok(lhs .iter_str() .map(|opt_s| opt_s.map(|s| compare_str_function(s, v))) - .collect_ca_trusted(lhs.name())) + .collect_ca_trusted(lhs.name().clone())) }, (1, rhs_len) => { // SAFETY: physical is in range of revmap @@ -73,12 +73,12 @@ where .map(|phys| rev_map_l.get_unchecked(phys)) }; let Some(v) = v else { - return Ok(BooleanChunked::full_null(lhs.name(), rhs_len)); + return Ok(BooleanChunked::full_null(lhs.name().clone(), rhs_len)); }; Ok(rhs .iter_str() .map(|opt_s| opt_s.map(|s| compare_str_function(v, s))) - .collect_ca_trusted(lhs.name())) + .collect_ca_trusted(lhs.name().clone())) }, (lhs_len, rhs_len) if lhs_len == rhs_len => Ok(lhs .iter_str() @@ -88,7 +88,7 @@ where (_, None) => None, (Some(l), Some(r)) => Some(compare_str_function(l, r)), }) - .collect_ca_trusted(lhs.name())), + .collect_ca_trusted(lhs.name().clone())), (lhs_len, rhs_len) => { polars_bail!(ComputeError: "Columns are of unequal length: {} vs {}",lhs_len,rhs_len) }, @@ -103,7 +103,7 @@ impl ChunkCompare<&CategoricalChunked> for CategoricalChunked { cat_equality_helper( self, rhs, - |lhs| replace_non_null(lhs.name(), &lhs.physical().chunks, false), + |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, false), UInt32Chunked::equal, ) } @@ -112,7 +112,7 @@ impl ChunkCompare<&CategoricalChunked> for CategoricalChunked { cat_equality_helper( self, rhs, - |lhs| BooleanChunked::full(lhs.name(), false, lhs.len()), + |lhs| BooleanChunked::full(lhs.name().clone(), false, lhs.len()), UInt32Chunked::equal_missing, ) } @@ -121,7 +121,7 @@ impl ChunkCompare<&CategoricalChunked> for CategoricalChunked { cat_equality_helper( self, rhs, - |lhs| replace_non_null(lhs.name(), &lhs.physical().chunks, true), + |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, true), UInt32Chunked::not_equal, ) } @@ -130,7 +130,7 @@ impl ChunkCompare<&CategoricalChunked> for CategoricalChunked { cat_equality_helper( self, rhs, - |lhs| BooleanChunked::full(lhs.name(), true, lhs.len()), + |lhs| BooleanChunked::full(lhs.name().clone(), true, lhs.len()), UInt32Chunked::not_equal_missing, ) } @@ -203,7 +203,7 @@ where cat_compare_function(lhs, rhs_cat.categorical().unwrap()) } else if rhs.len() == 1 { match rhs.get(0) { - None => Ok(BooleanChunked::full_null(lhs.name(), lhs.len())), + None => Ok(BooleanChunked::full_null(lhs.name().clone(), lhs.len())), Some(s) => cat_single_str_compare_helper( lhs, s, @@ -224,8 +224,8 @@ impl ChunkCompare<&StringChunked> for CategoricalChunked { cat_str_equality_helper( self, rhs, - |lhs| replace_non_null(lhs.name(), &lhs.physical().chunks, false), - |lhs| BooleanChunked::full_null(lhs.name(), lhs.len()), + |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, false), + |lhs| BooleanChunked::full_null(lhs.name().clone(), lhs.len()), |s1, s2| CategoricalChunked::equal(s1, s2), UInt32Chunked::equal, StringChunked::equal, @@ -235,7 +235,7 @@ impl ChunkCompare<&StringChunked> for CategoricalChunked { cat_str_equality_helper( self, rhs, - |lhs| 
BooleanChunked::full(lhs.name(), false, lhs.len()), + |lhs| BooleanChunked::full(lhs.name().clone(), false, lhs.len()), |lhs| lhs.physical().is_null(), |s1, s2| CategoricalChunked::equal_missing(s1, s2), UInt32Chunked::equal_missing, @@ -247,8 +247,8 @@ impl ChunkCompare<&StringChunked> for CategoricalChunked { cat_str_equality_helper( self, rhs, - |lhs| replace_non_null(lhs.name(), &lhs.physical().chunks, true), - |lhs| BooleanChunked::full_null(lhs.name(), lhs.len()), + |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, true), + |lhs| BooleanChunked::full_null(lhs.name().clone(), lhs.len()), |s1, s2| CategoricalChunked::not_equal(s1, s2), UInt32Chunked::not_equal, StringChunked::not_equal, @@ -258,7 +258,7 @@ impl ChunkCompare<&StringChunked> for CategoricalChunked { cat_str_equality_helper( self, rhs, - |lhs| BooleanChunked::full(lhs.name(), true, lhs.len()), + |lhs| BooleanChunked::full(lhs.name().clone(), true, lhs.len()), |lhs| !lhs.physical().is_null(), |s1, s2| CategoricalChunked::not_equal_missing(s1, s2), UInt32Chunked::not_equal_missing, @@ -371,7 +371,7 @@ where // SAFETY: indexing into bitmap with same length as original array opt_idx.map(|idx| unsafe { bitmap.get_bit_unchecked(idx as usize) }) })) - .with_name(lhs.name()), + .with_name(lhs.name().clone()), ) } } @@ -383,7 +383,7 @@ impl ChunkCompare<&str> for CategoricalChunked { cat_single_str_equality_helper( self, rhs, - |lhs| replace_non_null(lhs.name(), &lhs.physical().chunks, false), + |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, false), UInt32Chunked::equal, ) } @@ -392,7 +392,7 @@ impl ChunkCompare<&str> for CategoricalChunked { cat_single_str_equality_helper( self, rhs, - |lhs| BooleanChunked::full(lhs.name(), false, lhs.len()), + |lhs| BooleanChunked::full(lhs.name().clone(), false, lhs.len()), UInt32Chunked::equal_missing, ) } @@ -401,7 +401,7 @@ impl ChunkCompare<&str> for CategoricalChunked { cat_single_str_equality_helper( self, rhs, - |lhs| replace_non_null(lhs.name(), &lhs.physical().chunks, true), + |lhs| replace_non_null(lhs.name().clone(), &lhs.physical().chunks, true), UInt32Chunked::not_equal, ) } @@ -410,7 +410,7 @@ impl ChunkCompare<&str> for CategoricalChunked { cat_single_str_equality_helper( self, rhs, - |lhs| BooleanChunked::full(lhs.name(), true, lhs.len()), + |lhs| BooleanChunked::full(lhs.name().clone(), true, lhs.len()), UInt32Chunked::equal_missing, ) } diff --git a/crates/polars-core/src/chunked_array/comparison/mod.rs b/crates/polars-core/src/chunked_array/comparison/mod.rs index 8878aa07f11d..300f5f338cff 100644 --- a/crates/polars-core/src/chunked_array/comparison/mod.rs +++ b/crates/polars-core/src/chunked_array/comparison/mod.rs @@ -30,17 +30,22 @@ where if let Some(value) = rhs.get(0) { self.equal(value) } else { - BooleanChunked::full_null("", self.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) } }, (1, _) => { if let Some(value) = self.get(0) { rhs.equal(value) } else { - BooleanChunked::full_null("", rhs.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) } }, - _ => arity::binary_mut_values(self, rhs, |a, b| a.tot_eq_kernel(b).into(), ""), + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_eq_kernel(b).into(), + PlSmallStr::EMPTY, + ), } } @@ -65,7 +70,7 @@ where self, rhs, |a, b| a.tot_eq_missing_kernel(b).into(), - "", + PlSmallStr::EMPTY, ), } } @@ -77,17 +82,22 @@ where if let Some(value) = rhs.get(0) { self.not_equal(value) } else { - BooleanChunked::full_null("", self.len()) + 
BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) } }, (1, _) => { if let Some(value) = self.get(0) { rhs.not_equal(value) } else { - BooleanChunked::full_null("", rhs.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) } }, - _ => arity::binary_mut_values(self, rhs, |a, b| a.tot_ne_kernel(b).into(), ""), + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_ne_kernel(b).into(), + PlSmallStr::EMPTY, + ), } } @@ -112,7 +122,7 @@ where self, rhs, |a, b| a.tot_ne_missing_kernel(b).into(), - "", + PlSmallStr::EMPTY, ), } } @@ -124,17 +134,22 @@ where if let Some(value) = rhs.get(0) { self.lt(value) } else { - BooleanChunked::full_null("", self.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) } }, (1, _) => { if let Some(value) = self.get(0) { rhs.gt(value) } else { - BooleanChunked::full_null("", rhs.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) } }, - _ => arity::binary_mut_values(self, rhs, |a, b| a.tot_lt_kernel(b).into(), ""), + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_lt_kernel(b).into(), + PlSmallStr::EMPTY, + ), } } @@ -145,17 +160,22 @@ where if let Some(value) = rhs.get(0) { self.lt_eq(value) } else { - BooleanChunked::full_null("", self.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) } }, (1, _) => { if let Some(value) = self.get(0) { rhs.gt_eq(value) } else { - BooleanChunked::full_null("", rhs.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) } }, - _ => arity::binary_mut_values(self, rhs, |a, b| a.tot_le_kernel(b).into(), ""), + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_le_kernel(b).into(), + PlSmallStr::EMPTY, + ), } } @@ -172,35 +192,35 @@ impl ChunkCompare<&NullChunked> for NullChunked { type Item = BooleanChunked; fn equal(&self, rhs: &NullChunked) -> Self::Item { - BooleanChunked::full_null(self.name(), get_broadcast_length(self, rhs)) + BooleanChunked::full_null(self.name().clone(), get_broadcast_length(self, rhs)) } fn equal_missing(&self, rhs: &NullChunked) -> Self::Item { - BooleanChunked::full(self.name(), true, get_broadcast_length(self, rhs)) + BooleanChunked::full(self.name().clone(), true, get_broadcast_length(self, rhs)) } fn not_equal(&self, rhs: &NullChunked) -> Self::Item { - BooleanChunked::full_null(self.name(), get_broadcast_length(self, rhs)) + BooleanChunked::full_null(self.name().clone(), get_broadcast_length(self, rhs)) } fn not_equal_missing(&self, rhs: &NullChunked) -> Self::Item { - BooleanChunked::full(self.name(), false, get_broadcast_length(self, rhs)) + BooleanChunked::full(self.name().clone(), false, get_broadcast_length(self, rhs)) } fn gt(&self, rhs: &NullChunked) -> Self::Item { - BooleanChunked::full_null(self.name(), get_broadcast_length(self, rhs)) + BooleanChunked::full_null(self.name().clone(), get_broadcast_length(self, rhs)) } fn gt_eq(&self, rhs: &NullChunked) -> Self::Item { - BooleanChunked::full_null(self.name(), get_broadcast_length(self, rhs)) + BooleanChunked::full_null(self.name().clone(), get_broadcast_length(self, rhs)) } fn lt(&self, rhs: &NullChunked) -> Self::Item { - BooleanChunked::full_null(self.name(), get_broadcast_length(self, rhs)) + BooleanChunked::full_null(self.name().clone(), get_broadcast_length(self, rhs)) } fn lt_eq(&self, rhs: &NullChunked) -> Self::Item { - BooleanChunked::full_null(self.name(), get_broadcast_length(self, rhs)) + BooleanChunked::full_null(self.name().clone(), get_broadcast_length(self, rhs)) } } @@ -224,17 +244,22 @@ impl ChunkCompare<&BooleanChunked> for 
BooleanChunked { if let Some(value) = rhs.get(0) { arity::unary_mut_values(self, |arr| arr.tot_eq_kernel_broadcast(&value).into()) } else { - BooleanChunked::full_null("", self.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) } }, (1, _) => { if let Some(value) = self.get(0) { arity::unary_mut_values(rhs, |arr| arr.tot_eq_kernel_broadcast(&value).into()) } else { - BooleanChunked::full_null("", rhs.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) } }, - _ => arity::binary_mut_values(self, rhs, |a, b| a.tot_eq_kernel(b).into(), ""), + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_eq_kernel(b).into(), + PlSmallStr::EMPTY, + ), } } @@ -263,7 +288,7 @@ impl ChunkCompare<&BooleanChunked> for BooleanChunked { self, rhs, |a, b| a.tot_eq_missing_kernel(b).into(), - "", + PlSmallStr::EMPTY, ), } } @@ -275,17 +300,22 @@ impl ChunkCompare<&BooleanChunked> for BooleanChunked { if let Some(value) = rhs.get(0) { arity::unary_mut_values(self, |arr| arr.tot_ne_kernel_broadcast(&value).into()) } else { - BooleanChunked::full_null("", self.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) } }, (1, _) => { if let Some(value) = self.get(0) { arity::unary_mut_values(rhs, |arr| arr.tot_ne_kernel_broadcast(&value).into()) } else { - BooleanChunked::full_null("", rhs.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) } }, - _ => arity::binary_mut_values(self, rhs, |a, b| a.tot_ne_kernel(b).into(), ""), + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_ne_kernel(b).into(), + PlSmallStr::EMPTY, + ), } } @@ -314,7 +344,7 @@ impl ChunkCompare<&BooleanChunked> for BooleanChunked { self, rhs, |a, b| a.tot_ne_missing_kernel(b).into(), - "", + PlSmallStr::EMPTY, ), } } @@ -326,17 +356,22 @@ impl ChunkCompare<&BooleanChunked> for BooleanChunked { if let Some(value) = rhs.get(0) { arity::unary_mut_values(self, |arr| arr.tot_lt_kernel_broadcast(&value).into()) } else { - BooleanChunked::full_null("", self.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) } }, (1, _) => { if let Some(value) = self.get(0) { arity::unary_mut_values(rhs, |arr| arr.tot_gt_kernel_broadcast(&value).into()) } else { - BooleanChunked::full_null("", rhs.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) } }, - _ => arity::binary_mut_values(self, rhs, |a, b| a.tot_lt_kernel(b).into(), ""), + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_lt_kernel(b).into(), + PlSmallStr::EMPTY, + ), } } @@ -347,17 +382,22 @@ impl ChunkCompare<&BooleanChunked> for BooleanChunked { if let Some(value) = rhs.get(0) { arity::unary_mut_values(self, |arr| arr.tot_le_kernel_broadcast(&value).into()) } else { - BooleanChunked::full_null("", self.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) } }, (1, _) => { if let Some(value) = self.get(0) { arity::unary_mut_values(rhs, |arr| arr.tot_ge_kernel_broadcast(&value).into()) } else { - BooleanChunked::full_null("", rhs.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) } }, - _ => arity::binary_mut_values(self, rhs, |a, b| a.tot_le_kernel(b).into(), ""), + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_le_kernel(b).into(), + PlSmallStr::EMPTY, + ), } } @@ -415,17 +455,22 @@ impl ChunkCompare<&BinaryChunked> for BinaryChunked { if let Some(value) = rhs.get(0) { self.equal(value) } else { - BooleanChunked::full_null("", self.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) } }, (1, _) => { if let Some(value) = self.get(0) { 
rhs.equal(value) } else { - BooleanChunked::full_null("", rhs.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) } }, - _ => arity::binary_mut_values(self, rhs, |a, b| a.tot_eq_kernel(b).into(), ""), + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_eq_kernel(b).into(), + PlSmallStr::EMPTY, + ), } } @@ -450,7 +495,7 @@ impl ChunkCompare<&BinaryChunked> for BinaryChunked { self, rhs, |a, b| a.tot_eq_missing_kernel(b).into(), - "", + PlSmallStr::EMPTY, ), } } @@ -462,17 +507,22 @@ impl ChunkCompare<&BinaryChunked> for BinaryChunked { if let Some(value) = rhs.get(0) { self.not_equal(value) } else { - BooleanChunked::full_null("", self.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) } }, (1, _) => { if let Some(value) = self.get(0) { rhs.not_equal(value) } else { - BooleanChunked::full_null("", rhs.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) } }, - _ => arity::binary_mut_values(self, rhs, |a, b| a.tot_ne_kernel(b).into(), ""), + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_ne_kernel(b).into(), + PlSmallStr::EMPTY, + ), } } @@ -497,7 +547,7 @@ impl ChunkCompare<&BinaryChunked> for BinaryChunked { self, rhs, |a, b| a.tot_ne_missing_kernel(b).into(), - "", + PlSmallStr::EMPTY, ), } } @@ -509,17 +559,22 @@ impl ChunkCompare<&BinaryChunked> for BinaryChunked { if let Some(value) = rhs.get(0) { self.lt(value) } else { - BooleanChunked::full_null("", self.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) } }, (1, _) => { if let Some(value) = self.get(0) { rhs.gt(value) } else { - BooleanChunked::full_null("", rhs.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) } }, - _ => arity::binary_mut_values(self, rhs, |a, b| a.tot_lt_kernel(b).into(), ""), + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_lt_kernel(b).into(), + PlSmallStr::EMPTY, + ), } } @@ -530,17 +585,22 @@ impl ChunkCompare<&BinaryChunked> for BinaryChunked { if let Some(value) = rhs.get(0) { self.lt_eq(value) } else { - BooleanChunked::full_null("", self.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, self.len()) } }, (1, _) => { if let Some(value) = self.get(0) { rhs.gt_eq(value) } else { - BooleanChunked::full_null("", rhs.len()) + BooleanChunked::full_null(PlSmallStr::EMPTY, rhs.len()) } }, - _ => arity::binary_mut_values(self, rhs, |a, b| a.tot_le_kernel(b).into(), ""), + _ => arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_le_kernel(b).into(), + PlSmallStr::EMPTY, + ), } } @@ -560,13 +620,13 @@ where { match (lhs.len(), rhs.len()) { (_, 1) => { - let right = rhs.get_as_series(0).map(|s| s.with_name("")); + let right = rhs.get_as_series(0).map(|s| s.with_name(PlSmallStr::EMPTY)); lhs.amortized_iter() .map(|left| op(left.as_ref().map(|us| us.as_ref()), right.as_ref())) .collect_trusted() }, (1, _) => { - let left = lhs.get_as_series(0).map(|s| s.with_name("")); + let left = lhs.get_as_series(0).map(|s| s.with_name(PlSmallStr::EMPTY)); rhs.amortized_iter() .map(|right| op(left.as_ref(), right.as_ref().map(|us| us.as_ref()))) .collect_trusted() @@ -657,7 +717,7 @@ where { if a.len() != b.len() || a.struct_fields().len() != b.struct_fields().len() { // polars_ensure!(a.len() == 1 || b.len() == 1, ShapeMismatch: "length lhs: {}, length rhs: {}", a.len(), b.len()); - BooleanChunked::full("", value, a.len()) + BooleanChunked::full(PlSmallStr::EMPTY, value, a.len()) } else { let (a, b) = align_chunks_binary(a, b); let mut out = a @@ -729,30 +789,50 @@ impl ChunkCompare<&ArrayChunked> for ArrayChunked { 
type Item = BooleanChunked; fn equal(&self, rhs: &ArrayChunked) -> BooleanChunked { if self.width() != rhs.width() { - return BooleanChunked::full("", false, self.len()); + return BooleanChunked::full(PlSmallStr::EMPTY, false, self.len()); } - arity::binary_mut_values(self, rhs, |a, b| a.tot_eq_kernel(b).into(), "") + arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_eq_kernel(b).into(), + PlSmallStr::EMPTY, + ) } fn equal_missing(&self, rhs: &ArrayChunked) -> BooleanChunked { if self.width() != rhs.width() { - return BooleanChunked::full("", false, self.len()); + return BooleanChunked::full(PlSmallStr::EMPTY, false, self.len()); } - arity::binary_mut_with_options(self, rhs, |a, b| a.tot_eq_missing_kernel(b).into(), "") + arity::binary_mut_with_options( + self, + rhs, + |a, b| a.tot_eq_missing_kernel(b).into(), + PlSmallStr::EMPTY, + ) } fn not_equal(&self, rhs: &ArrayChunked) -> BooleanChunked { if self.width() != rhs.width() { - return BooleanChunked::full("", true, self.len()); + return BooleanChunked::full(PlSmallStr::EMPTY, true, self.len()); } - arity::binary_mut_values(self, rhs, |a, b| a.tot_ne_kernel(b).into(), "") + arity::binary_mut_values( + self, + rhs, + |a, b| a.tot_ne_kernel(b).into(), + PlSmallStr::EMPTY, + ) } fn not_equal_missing(&self, rhs: &ArrayChunked) -> Self::Item { if self.width() != rhs.width() { - return BooleanChunked::full("", true, self.len()); + return BooleanChunked::full(PlSmallStr::EMPTY, true, self.len()); } - arity::binary_mut_with_options(self, rhs, |a, b| a.tot_ne_missing_kernel(b).into(), "") + arity::binary_mut_with_options( + self, + rhs, + |a, b| a.tot_ne_missing_kernel(b).into(), + PlSmallStr::EMPTY, + ) } // following are not implemented because gt, lt comparison of series don't make sense @@ -778,7 +858,7 @@ impl Not for &BooleanChunked { fn not(self) -> Self::Output { let chunks = self.downcast_iter().map(compute::boolean::not); - ChunkedArray::from_chunk_iter(self.name(), chunks) + ChunkedArray::from_chunk_iter(self.name().clone(), chunks) } } @@ -913,17 +993,20 @@ mod test { use crate::prelude::*; pub(crate) fn create_two_chunked() -> (Int32Chunked, Int32Chunked) { - let mut a1 = Int32Chunked::new("a", &[1, 2, 3]); - let a2 = Int32Chunked::new("a", &[4, 5, 6]); - let a3 = Int32Chunked::new("a", &[1, 2, 3, 4, 5, 6]); - a1.append(&a2); + let mut a1 = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3]); + let a2 = Int32Chunked::new(PlSmallStr::from_static("a"), &[4, 5, 6]); + let a3 = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3, 4, 5, 6]); + a1.append(&a2).unwrap(); (a1, a3) } #[test] fn test_bitwise_ops() { - let a = BooleanChunked::new("a", &[true, false, false]); - let b = BooleanChunked::new("b", &[Some(true), Some(true), None]); + let a = BooleanChunked::new(PlSmallStr::from_static("a"), &[true, false, false]); + let b = BooleanChunked::new( + PlSmallStr::from_static("b"), + &[Some(true), Some(true), None], + ); assert_eq!(Vec::from(&a | &b), &[Some(true), Some(true), None]); assert_eq!(Vec::from(&a & &b), &[Some(true), Some(false), Some(false)]); assert_eq!(Vec::from(!b), &[Some(false), Some(false), None]); @@ -1049,7 +1132,9 @@ mod test { let a2: Int32Chunked = [Some(1), Some(2), Some(3)].iter().copied().collect(); let mut a2_2chunks: Int32Chunked = [Some(1), Some(2)].iter().copied().collect(); - a2_2chunks.append(&[Some(3)].iter().copied().collect()); + a2_2chunks + .append(&[Some(3)].iter().copied().collect()) + .unwrap(); assert_eq!( a1.equal(&a2).into_iter().collect::>(), @@ -1129,9 +1214,9 @@ mod test { 
#[test] fn test_kleene() { - let a = BooleanChunked::new("", &[Some(true), Some(false), None]); - let trues = BooleanChunked::from_slice("", &[true, true, true]); - let falses = BooleanChunked::from_slice("", &[false, false, false]); + let a = BooleanChunked::new(PlSmallStr::EMPTY, &[Some(true), Some(false), None]); + let trues = BooleanChunked::from_slice(PlSmallStr::EMPTY, &[true, true, true]); + let falses = BooleanChunked::from_slice(PlSmallStr::EMPTY, &[false, false, false]); let c = &a | &trues; assert_eq!(Vec::from(&c), &[Some(true), Some(true), Some(true)]); @@ -1142,9 +1227,9 @@ mod test { #[test] fn list_broadcasting_lists() { - let s_el = Series::new("", &[1, 2, 3]); - let s_lhs = Series::new("", &[s_el.clone(), s_el.clone()]); - let s_rhs = Series::new("", &[s_el.clone()]); + let s_el = Series::new(PlSmallStr::EMPTY, &[1, 2, 3]); + let s_lhs = Series::new(PlSmallStr::EMPTY, &[s_el.clone(), s_el.clone()]); + let s_rhs = Series::new(PlSmallStr::EMPTY, &[s_el.clone()]); let result = s_lhs.list().unwrap().equal(s_rhs.list().unwrap()); assert_eq!(result.len(), 2); @@ -1153,9 +1238,9 @@ mod test { #[test] fn test_broadcasting_bools() { - let a = BooleanChunked::from_slice("", &[true, false, true]); - let true_ = BooleanChunked::from_slice("", &[true]); - let false_ = BooleanChunked::from_slice("", &[false]); + let a = BooleanChunked::from_slice(PlSmallStr::EMPTY, &[true, false, true]); + let true_ = BooleanChunked::from_slice(PlSmallStr::EMPTY, &[true]); + let false_ = BooleanChunked::from_slice(PlSmallStr::EMPTY, &[false]); let out = a.equal(&true_); assert_eq!(Vec::from(&out), &[Some(true), Some(false), Some(true)]); @@ -1211,9 +1296,10 @@ mod test { let out = false_.lt_eq(&a); assert_eq!(Vec::from(&out), &[Some(true), Some(true), Some(true)]); - let a = BooleanChunked::from_slice_options("", &[Some(true), Some(false), None]); - let all_true = BooleanChunked::from_slice("", &[true, true, true]); - let all_false = BooleanChunked::from_slice("", &[false, false, false]); + let a = + BooleanChunked::from_slice_options(PlSmallStr::EMPTY, &[Some(true), Some(false), None]); + let all_true = BooleanChunked::from_slice(PlSmallStr::EMPTY, &[true, true, true]); + let all_false = BooleanChunked::from_slice(PlSmallStr::EMPTY, &[false, false, false]); let out = a.equal(&true_); assert_eq!(Vec::from(&out), &[Some(true), Some(false), None]); let out = a.not_equal(&true_); @@ -1235,9 +1321,9 @@ mod test { #[test] fn test_broadcasting_numeric() { - let a = Int32Chunked::from_slice("", &[1, 2, 3]); - let one = Int32Chunked::from_slice("", &[1]); - let three = Int32Chunked::from_slice("", &[3]); + let a = Int32Chunked::from_slice(PlSmallStr::EMPTY, &[1, 2, 3]); + let one = Int32Chunked::from_slice(PlSmallStr::EMPTY, &[1]); + let three = Int32Chunked::from_slice(PlSmallStr::EMPTY, &[3]); let out = a.equal(&one); assert_eq!(Vec::from(&out), &[Some(true), Some(false), Some(false)]); diff --git a/crates/polars-core/src/chunked_array/comparison/scalar.rs b/crates/polars-core/src/chunked_array/comparison/scalar.rs index f47f23780f82..1c632299c1e4 100644 --- a/crates/polars-core/src/chunked_array/comparison/scalar.rs +++ b/crates/polars-core/src/chunked_array/comparison/scalar.rs @@ -56,7 +56,7 @@ where BooleanArray::from_data_default(mask.into(), None) }); - let mut ca = BooleanChunked::from_chunk_iter(ca.name(), chunks); + let mut ca = BooleanChunked::from_chunk_iter(ca.name().clone(), chunks); ca.set_sorted_flag(output_order.unwrap_or(IsSorted::Ascending)); ca } @@ -235,7 +235,7 @@ mod test { #[test] fn 
test_binary_search_cmp() { - let mut s = Series::new("", &[1, 1, 2, 2, 4, 8]); + let mut s = Series::new(PlSmallStr::EMPTY, &[1, 1, 2, 2, 4, 8]); s.set_sorted_flag(IsSorted::Ascending); let out = s.gt(10).unwrap(); assert!(!out.any()); @@ -246,12 +246,12 @@ mod test { let out = s.gt(2).unwrap(); assert_eq!( out.into_series(), - Series::new("", [false, false, false, false, true, true]) + Series::new(PlSmallStr::EMPTY, [false, false, false, false, true, true]) ); let out = s.gt(3).unwrap(); assert_eq!( out.into_series(), - Series::new("", [false, false, false, false, true, true]) + Series::new(PlSmallStr::EMPTY, [false, false, false, false, true, true]) ); let out = s.gt_eq(10).unwrap(); @@ -262,12 +262,12 @@ mod test { let out = s.gt_eq(2).unwrap(); assert_eq!( out.into_series(), - Series::new("", [false, false, true, true, true, true]) + Series::new(PlSmallStr::EMPTY, [false, false, true, true, true, true]) ); let out = s.gt_eq(3).unwrap(); assert_eq!( out.into_series(), - Series::new("", [false, false, false, false, true, true]) + Series::new(PlSmallStr::EMPTY, [false, false, false, false, true, true]) ); let out = s.lt(10).unwrap(); @@ -278,12 +278,12 @@ mod test { let out = s.lt(2).unwrap(); assert_eq!( out.into_series(), - Series::new("", [true, true, false, false, false, false]) + Series::new(PlSmallStr::EMPTY, [true, true, false, false, false, false]) ); let out = s.lt(3).unwrap(); assert_eq!( out.into_series(), - Series::new("", [true, true, true, true, false, false]) + Series::new(PlSmallStr::EMPTY, [true, true, true, true, false, false]) ); let out = s.lt_eq(10).unwrap(); @@ -294,12 +294,12 @@ mod test { let out = s.lt_eq(2).unwrap(); assert_eq!( out.into_series(), - Series::new("", [true, true, true, true, false, false]) + Series::new(PlSmallStr::EMPTY, [true, true, true, true, false, false]) ); let out = s.lt(3).unwrap(); assert_eq!( out.into_series(), - Series::new("", [true, true, true, true, false, false]) + Series::new(PlSmallStr::EMPTY, [true, true, true, true, false, false]) ); } } diff --git a/crates/polars-core/src/chunked_array/float.rs b/crates/polars-core/src/chunked_array/float.rs index dc09024b704c..8376629cc403 100644 --- a/crates/polars-core/src/chunked_array/float.rs +++ b/crates/polars-core/src/chunked_array/float.rs @@ -30,7 +30,7 @@ where let chunks = self .downcast_iter() .map(|arr| set_at_nulls(arr, T::Native::nan())); - ChunkedArray::from_chunk_iter(self.name(), chunks) + ChunkedArray::from_chunk_iter(self.name().clone(), chunks) } } diff --git a/crates/polars-core/src/chunked_array/from.rs b/crates/polars-core/src/chunked_array/from.rs index 74f12ccc58ce..bf5c748eeed1 100644 --- a/crates/polars-core/src/chunked_array/from.rs +++ b/crates/polars-core/src/chunked_array/from.rs @@ -6,7 +6,7 @@ use super::*; fn from_chunks_list_dtype(chunks: &mut Vec, dtype: DataType) -> DataType { // ensure we don't get List let dtype = if let Some(arr) = chunks.get(0) { - arr.data_type().into() + arr.dtype().into() } else { dtype }; @@ -27,9 +27,9 @@ fn from_chunks_list_dtype(chunks: &mut Vec, dtype: DataType) -> DataTy let values_arr = list_arr.values(); let cat = unsafe { Series::_try_from_arrow_unchecked( - "", + PlSmallStr::EMPTY, vec![values_arr.clone()], - values_arr.data_type(), + values_arr.dtype(), ) .unwrap() }; @@ -59,9 +59,9 @@ fn from_chunks_list_dtype(chunks: &mut Vec, dtype: DataType) -> DataTy let values_arr = list_arr.values(); let cat = unsafe { Series::_try_from_arrow_unchecked( - "", + PlSmallStr::EMPTY, vec![values_arr.clone()], - values_arr.data_type(), 
+ values_arr.dtype(), ) .unwrap() }; @@ -88,7 +88,7 @@ where A: Array, { fn from(arr: A) -> Self { - Self::with_chunk("", arr) + Self::with_chunk(PlSmallStr::EMPTY, arr) } } @@ -96,7 +96,7 @@ impl ChunkedArray where T: PolarsDataType, { - pub fn with_chunk(name: &str, arr: A) -> Self + pub fn with_chunk(name: PlSmallStr, arr: A) -> Self where A: Array, T: PolarsDataType, @@ -112,7 +112,7 @@ where Self::from_chunk_iter_like(ca, std::iter::once(arr)) } - pub fn from_chunk_iter(name: &str, iter: I) -> Self + pub fn from_chunk_iter(name: PlSmallStr, iter: I) -> Self where I: IntoIterator, T: PolarsDataType::Item>, @@ -135,10 +135,12 @@ where .into_iter() .map(|x| Box::new(x) as Box) .collect(); - unsafe { Self::from_chunks_and_dtype_unchecked(ca.name(), chunks, ca.dtype().clone()) } + unsafe { + Self::from_chunks_and_dtype_unchecked(ca.name().clone(), chunks, ca.dtype().clone()) + } } - pub fn try_from_chunk_iter(name: &str, iter: I) -> Result + pub fn try_from_chunk_iter(name: PlSmallStr, iter: I) -> Result where I: IntoIterator>, T: PolarsDataType, @@ -187,7 +189,7 @@ where /// /// # Safety /// The Arrow datatype of all chunks must match the [`PolarsDataType`] `T`. - pub unsafe fn from_chunks(name: &str, mut chunks: Vec) -> Self { + pub unsafe fn from_chunks(name: PlSmallStr, mut chunks: Vec) -> Self { let dtype = match T::get_dtype() { dtype @ DataType::List(_) => from_chunks_list_dtype(&mut chunks, dtype), #[cfg(feature = "dtype-array")] @@ -210,7 +212,7 @@ where /// # Safety /// The Arrow datatype of all chunks must match the [`PolarsDataType`] `T`. pub unsafe fn from_chunks_and_dtype( - name: &str, + name: PlSmallStr, chunks: Vec, dtype: DataType, ) -> Self { @@ -219,10 +221,7 @@ where #[cfg(debug_assertions)] { if !chunks.is_empty() && !chunks[0].is_empty() && dtype.is_primitive() { - assert_eq!( - chunks[0].data_type(), - &dtype.to_arrow(CompatLevel::newest()) - ) + assert_eq!(chunks[0].dtype(), &dtype.to_arrow(CompatLevel::newest())) } } let field = Arc::new(Field::new(name, dtype)); @@ -230,7 +229,7 @@ where } pub(crate) unsafe fn from_chunks_and_dtype_unchecked( - name: &str, + name: PlSmallStr, chunks: Vec, dtype: DataType, ) -> Self { @@ -252,12 +251,16 @@ where T: PolarsNumericType, { /// Create a new ChunkedArray by taking ownership of the Vec. This operation is zero copy. - pub fn from_vec(name: &str, v: Vec) -> Self { + pub fn from_vec(name: PlSmallStr, v: Vec) -> Self { Self::with_chunk(name, to_primitive::(v, None)) } /// Create a new ChunkedArray from a Vec and a validity mask. - pub fn from_vec_validity(name: &str, values: Vec, buffer: Option) -> Self { + pub fn from_vec_validity( + name: PlSmallStr, + values: Vec, + buffer: Option, + ) -> Self { let arr = to_array::(values, buffer); ChunkedArray::new_with_compute_len(Arc::new(Field::new(name, T::get_dtype())), vec![arr]) } @@ -267,7 +270,7 @@ where /// # Safety /// The lifetime will be bound to the lifetime of the slice. /// This will not be checked by the borrowchecker. - pub unsafe fn mmap_slice(name: &str, values: &[T::Native]) -> Self { + pub unsafe fn mmap_slice(name: PlSmallStr, values: &[T::Native]) -> Self { Self::with_chunk(name, arrow::ffi::mmap::slice(values)) } } @@ -278,7 +281,7 @@ impl BooleanChunked { /// # Safety /// The lifetime will be bound to the lifetime of the slice. /// This will not be checked by the borrowchecker. 
- pub unsafe fn mmap_slice(name: &str, values: &[u8], offset: usize, len: usize) -> Self { + pub unsafe fn mmap_slice(name: PlSmallStr, values: &[u8], offset: usize, len: usize) -> Self { let arr = arrow::ffi::mmap::bitmap(values, offset, len).unwrap(); Self::with_chunk(name, arr) } diff --git a/crates/polars-core/src/chunked_array/from_iterator.rs b/crates/polars-core/src/chunked_array/from_iterator.rs index 8d2499983d73..766ef94acc8e 100644 --- a/crates/polars-core/src/chunked_array/from_iterator.rs +++ b/crates/polars-core/src/chunked_array/from_iterator.rs @@ -20,7 +20,7 @@ where #[inline] fn from_iter>>(iter: I) -> Self { // TODO: eliminate this FromIterator implementation entirely. - iter.into_iter().collect_ca("") + iter.into_iter().collect_ca(PlSmallStr::EMPTY) } } @@ -35,7 +35,7 @@ where fn from_iter>(iter: I) -> Self { // 2021-02-07: aligned vec was ~2x faster than arrow collect. let av = iter.into_iter().collect::>(); - NoNull::new(ChunkedArray::from_vec("", av)) + NoNull::new(ChunkedArray::from_vec(PlSmallStr::EMPTY, av)) } } @@ -49,14 +49,14 @@ impl FromIterator> for ChunkedArray { impl FromIterator for BooleanChunked { #[inline] fn from_iter>(iter: I) -> Self { - iter.into_iter().collect_ca("") + iter.into_iter().collect_ca(PlSmallStr::EMPTY) } } impl FromIterator for NoNull { #[inline] fn from_iter>(iter: I) -> Self { - NoNull::new(iter.into_iter().collect_ca("")) + NoNull::new(iter.into_iter().collect_ca(PlSmallStr::EMPTY)) } } @@ -69,7 +69,7 @@ where #[inline] fn from_iter>>(iter: I) -> Self { let arr = MutableBinaryViewArray::from_iterator(iter.into_iter()).freeze(); - ChunkedArray::with_chunk("", arr) + ChunkedArray::with_chunk(PlSmallStr::EMPTY, arr) } } @@ -95,7 +95,7 @@ where #[inline] fn from_iter>(iter: I) -> Self { let arr = MutableBinaryViewArray::from_values_iter(iter.into_iter()).freeze(); - ChunkedArray::with_chunk("", arr) + ChunkedArray::with_chunk(PlSmallStr::EMPTY, arr) } } @@ -107,7 +107,7 @@ where #[inline] fn from_iter>>(iter: I) -> Self { let arr = MutableBinaryViewArray::from_iter(iter).freeze(); - ChunkedArray::with_chunk("", arr) + ChunkedArray::with_chunk(PlSmallStr::EMPTY, arr) } } @@ -118,7 +118,7 @@ where #[inline] fn from_iter>(iter: I) -> Self { let arr = MutableBinaryViewArray::from_values_iter(iter.into_iter()).freeze(); - ChunkedArray::with_chunk("", arr) + ChunkedArray::with_chunk(PlSmallStr::EMPTY, arr) } } @@ -134,11 +134,16 @@ where // first take one to get the dtype. let v = match it.next() { Some(v) => v, - None => return ListChunked::full_null("", 0), + None => return ListChunked::full_null(PlSmallStr::EMPTY, 0), }; // We don't know the needed capacity. We arbitrarily choose an average of 5 elements per series. - let mut builder = - get_list_builder(v.borrow().dtype(), capacity * 5, capacity, "collected").unwrap(); + let mut builder = get_list_builder( + v.borrow().dtype(), + capacity * 5, + capacity, + PlSmallStr::EMPTY, + ) + .unwrap(); builder.append_series(v.borrow()).unwrap(); for s in it { @@ -166,7 +171,7 @@ impl FromIterator> for ListChunked { Some(None) => { init_null_count += 1; }, - None => return ListChunked::full_null("", init_null_count), + None => return ListChunked::full_null(PlSmallStr::EMPTY, init_null_count), } } @@ -182,7 +187,8 @@ impl FromIterator> for ListChunked { // the empty arrays is then not added (we add an extra offset instead) // the next non-empty series then must have the correct dtype. 
if matches!(first_s.dtype(), DataType::Null) && first_s.is_empty() { - let mut builder = AnonymousOwnedListBuilder::new("collected", capacity, None); + let mut builder = + AnonymousOwnedListBuilder::new(PlSmallStr::EMPTY, capacity, None); for _ in 0..init_null_count { builder.append_null(); } @@ -197,7 +203,7 @@ impl FromIterator> for ListChunked { #[cfg(feature = "object")] DataType::Object(_, _) => { let mut builder = - first_s.get_list_builder("collected", capacity * 5, capacity); + first_s.get_list_builder(PlSmallStr::EMPTY, capacity * 5, capacity); for _ in 0..init_null_count { builder.append_null(); } @@ -214,7 +220,7 @@ impl FromIterator> for ListChunked { first_s.dtype(), capacity * 5, capacity, - "collected", + PlSmallStr::EMPTY, ) .unwrap(); @@ -238,7 +244,7 @@ impl FromIterator> for ListChunked { impl FromIterator>> for ListChunked { #[inline] fn from_iter>>>(iter: I) -> Self { - iter.into_iter().collect_ca("collected") + iter.into_iter().collect_ca(PlSmallStr::EMPTY) } } @@ -274,7 +280,7 @@ impl FromIterator> for ObjectChunked { len, }); ChunkedArray::new_with_compute_len( - Arc::new(Field::new("", get_object_type::())), + Arc::new(Field::new(PlSmallStr::EMPTY, get_object_type::())), vec![arr], ) } diff --git a/crates/polars-core/src/chunked_array/from_iterator_par.rs b/crates/polars-core/src/chunked_array/from_iterator_par.rs index eaf45d1d651f..5c9abf4620af 100644 --- a/crates/polars-core/src/chunked_array/from_iterator_par.rs +++ b/crates/polars-core/src/chunked_array/from_iterator_par.rs @@ -72,7 +72,7 @@ where let vectors = collect_into_linked_list_vec(iter); let vectors = vectors.into_iter().collect::>(); let values = flatten_par(&vectors); - NoNull::new(ChunkedArray::new_vec("", values)) + NoNull::new(ChunkedArray::new_vec(PlSmallStr::EMPTY, values)) } } @@ -82,21 +82,21 @@ where { fn from_par_iter>>(iter: I) -> Self { let chunks = collect_into_linked_list(iter, MutablePrimitiveArray::new); - Self::from_chunk_iter("", chunks).optional_rechunk() + Self::from_chunk_iter(PlSmallStr::EMPTY, chunks).optional_rechunk() } } impl FromParallelIterator for BooleanChunked { fn from_par_iter>(iter: I) -> Self { let chunks = collect_into_linked_list(iter, MutableBooleanArray::new); - Self::from_chunk_iter("", chunks).optional_rechunk() + Self::from_chunk_iter(PlSmallStr::EMPTY, chunks).optional_rechunk() } } impl FromParallelIterator> for BooleanChunked { fn from_par_iter>>(iter: I) -> Self { let chunks = collect_into_linked_list(iter, MutableBooleanArray::new); - Self::from_chunk_iter("", chunks).optional_rechunk() + Self::from_chunk_iter(PlSmallStr::EMPTY, chunks).optional_rechunk() } } @@ -106,7 +106,7 @@ where { fn from_par_iter>(iter: I) -> Self { let chunks = collect_into_linked_list(iter, MutableBinaryViewArray::new); - Self::from_chunk_iter("", chunks).optional_rechunk() + Self::from_chunk_iter(PlSmallStr::EMPTY, chunks).optional_rechunk() } } @@ -116,7 +116,7 @@ where { fn from_par_iter>(iter: I) -> Self { let chunks = collect_into_linked_list(iter, MutableBinaryViewArray::new); - Self::from_chunk_iter("", chunks).optional_rechunk() + Self::from_chunk_iter(PlSmallStr::EMPTY, chunks).optional_rechunk() } } @@ -126,7 +126,7 @@ where { fn from_par_iter>>(iter: I) -> Self { let chunks = collect_into_linked_list(iter, MutableBinaryViewArray::new); - Self::from_chunk_iter("", chunks).optional_rechunk() + Self::from_chunk_iter(PlSmallStr::EMPTY, chunks).optional_rechunk() } } @@ -136,12 +136,12 @@ where { fn from_par_iter>>(iter: I) -> Self { let chunks = 
collect_into_linked_list(iter, MutableBinaryViewArray::new); - Self::from_chunk_iter("", chunks).optional_rechunk() + Self::from_chunk_iter(PlSmallStr::EMPTY, chunks).optional_rechunk() } } pub trait FromParIterWithDtype { - fn from_par_iter_with_dtype(iter: I, name: &str, dtype: DataType) -> Self + fn from_par_iter_with_dtype(iter: I, name: PlSmallStr, dtype: DataType) -> Self where I: IntoParallelIterator, Self: Sized; @@ -171,7 +171,7 @@ fn get_dtype(vectors: &LinkedList>>) -> DataType { } fn materialize_list( - name: &str, + name: PlSmallStr, vectors: &LinkedList>>, dtype: DataType, value_capacity: usize, @@ -217,15 +217,21 @@ impl FromParallelIterator> for ListChunked { let value_capacity = get_value_cap(&vectors); let dtype = get_dtype(&vectors); if let DataType::Null = dtype { - ListChunked::full_null_with_dtype("", list_capacity, &DataType::Null) + ListChunked::full_null_with_dtype(PlSmallStr::EMPTY, list_capacity, &DataType::Null) } else { - materialize_list("", &vectors, dtype, value_capacity, list_capacity) + materialize_list( + PlSmallStr::EMPTY, + &vectors, + dtype, + value_capacity, + list_capacity, + ) } } } impl FromParIterWithDtype> for ListChunked { - fn from_par_iter_with_dtype(iter: I, name: &str, dtype: DataType) -> Self + fn from_par_iter_with_dtype(iter: I, name: PlSmallStr, dtype: DataType) -> Self where I: IntoParallelIterator>, Self: Sized, @@ -245,7 +251,7 @@ impl FromParIterWithDtype> for ListChunked { pub trait ChunkedCollectParIterExt: ParallelIterator { fn collect_ca_with_dtype>( self, - name: &str, + name: PlSmallStr, dtype: DataType, ) -> B where @@ -264,7 +270,7 @@ where T: Send, E: Send, { - fn from_par_iter_with_dtype(par_iter: I, name: &str, dtype: DataType) -> Self + fn from_par_iter_with_dtype(par_iter: I, name: PlSmallStr, dtype: DataType) -> Self where I: IntoParallelIterator>, { diff --git a/crates/polars-core/src/chunked_array/iterator/mod.rs b/crates/polars-core/src/chunked_array/iterator/mod.rs index 64b7fb30bb7c..7756153891c6 100644 --- a/crates/polars-core/src/chunked_array/iterator/mod.rs +++ b/crates/polars-core/src/chunked_array/iterator/mod.rs @@ -220,7 +220,7 @@ impl<'a> IntoIterator for &'a ListChunked { .trust_my_length(self.len()) .map(move |arr| { Some(Series::from_chunks_and_dtype_unchecked( - "", + PlSmallStr::EMPTY, vec![arr], dtype, )) @@ -236,7 +236,11 @@ impl<'a> IntoIterator for &'a ListChunked { .trust_my_length(self.len()) .map(move |arr| { arr.map(|arr| { - Series::from_chunks_and_dtype_unchecked("", vec![arr], dtype) + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr], + dtype, + ) }) }), ) @@ -256,7 +260,13 @@ impl ListChunked { unsafe { self.downcast_iter() .flat_map(|arr| arr.values_iter()) - .map(move |arr| Series::from_chunks_and_dtype_unchecked("", vec![arr], inner_type)) + .map(move |arr| { + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr], + inner_type, + ) + }) .trust_my_length(self.len()) } } @@ -278,7 +288,7 @@ impl<'a> IntoIterator for &'a ArrayChunked { .trust_my_length(self.len()) .map(move |arr| { Some(Series::from_chunks_and_dtype_unchecked( - "", + PlSmallStr::EMPTY, vec![arr], dtype, )) @@ -294,7 +304,11 @@ impl<'a> IntoIterator for &'a ArrayChunked { .trust_my_length(self.len()) .map(move |arr| { arr.map(|arr| { - Series::from_chunks_and_dtype_unchecked("", vec![arr], dtype) + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr], + dtype, + ) }) }), ) @@ -336,7 +350,7 @@ impl<'a> Iterator for FixedSizeListIterNoNull<'a> { self.current 
+= 1; unsafe { Some(Series::from_chunks_and_dtype_unchecked( - "", + PlSmallStr::EMPTY, vec![self.array.value_unchecked(old)], &self.inner_type, )) @@ -360,7 +374,13 @@ impl<'a> DoubleEndedIterator for FixedSizeListIterNoNull<'a> { } else { self.current_end -= 1; unsafe { - Some(Series::try_from(("", self.array.value_unchecked(self.current_end))).unwrap()) + Some( + Series::try_from(( + PlSmallStr::EMPTY, + self.array.value_unchecked(self.current_end), + )) + .unwrap(), + ) } } } @@ -456,9 +476,9 @@ mod test { #[test] fn out_of_bounds() { - let mut a = UInt32Chunked::from_slice("a", &[1, 2, 3]); - let b = UInt32Chunked::from_slice("a", &[1, 2, 3]); - a.append(&b); + let mut a = UInt32Chunked::from_slice(PlSmallStr::from_static("a"), &[1, 2, 3]); + let b = UInt32Chunked::from_slice(PlSmallStr::from_static("a"), &[1, 2, 3]); + a.append(&b).unwrap(); let v = a.into_iter().collect::>(); assert_eq!( @@ -482,7 +502,10 @@ mod test { ($test_name:ident, $ca_type:ty, $first_val:expr, $second_val:expr, $third_val:expr) => { #[test] fn $test_name() { - let a = <$ca_type>::from_slice("test", &[$first_val, $second_val, $third_val]); + let a = <$ca_type>::from_slice( + PlSmallStr::from_static("test"), + &[$first_val, $second_val, $third_val], + ); // normal iterator let mut it = a.into_iter(); @@ -543,7 +566,10 @@ mod test { ($test_name:ident, $ca_type:ty, $first_val:expr, $second_val:expr, $third_val:expr) => { #[test] fn $test_name() { - let a = <$ca_type>::new("test", &[$first_val, $second_val, $third_val]); + let a = <$ca_type>::new( + PlSmallStr::from_static("test"), + &[$first_val, $second_val, $third_val], + ); // normal iterator let mut it = a.into_iter(); @@ -622,9 +648,12 @@ mod test { ($test_name:ident, $ca_type:ty, $first_val:expr, $second_val:expr, $third_val:expr) => { #[test] fn $test_name() { - let mut a = <$ca_type>::from_slice("test", &[$first_val, $second_val]); - let a_b = <$ca_type>::from_slice("", &[$third_val]); - a.append(&a_b); + let mut a = <$ca_type>::from_slice( + PlSmallStr::from_static("test"), + &[$first_val, $second_val], + ); + let a_b = <$ca_type>::from_slice(PlSmallStr::EMPTY, &[$third_val]); + a.append(&a_b).unwrap(); // normal iterator let mut it = a.into_iter(); @@ -685,9 +714,10 @@ mod test { ($test_name:ident, $ca_type:ty, $first_val:expr, $second_val:expr, $third_val:expr) => { #[test] fn $test_name() { - let mut a = <$ca_type>::new("test", &[$first_val, $second_val]); - let a_b = <$ca_type>::new("", &[$third_val]); - a.append(&a_b); + let mut a = + <$ca_type>::new(PlSmallStr::from_static("test"), &[$first_val, $second_val]); + let a_b = <$ca_type>::new(PlSmallStr::EMPTY, &[$third_val]); + a.append(&a_b).unwrap(); // normal iterator let mut it = a.into_iter(); @@ -766,7 +796,10 @@ mod test { ($test_name:ident, $ca_type:ty, $first_val:expr, $second_val:expr, $third_val:expr) => { #[test] fn $test_name() { - let a = <$ca_type>::from_slice("test", &[$first_val, $second_val, $third_val]); + let a = <$ca_type>::from_slice( + PlSmallStr::from_static("test"), + &[$first_val, $second_val, $third_val], + ); // normal iterator let mut it = a.into_no_null_iter(); @@ -839,9 +872,12 @@ mod test { ($test_name:ident, $ca_type:ty, $first_val:expr, $second_val:expr, $third_val:expr) => { #[test] fn $test_name() { - let mut a = <$ca_type>::from_slice("test", &[$first_val, $second_val]); - let a_b = <$ca_type>::from_slice("", &[$third_val]); - a.append(&a_b); + let mut a = <$ca_type>::from_slice( + PlSmallStr::from_static("test"), + &[$first_val, $second_val], + ); + let a_b = 
<$ca_type>::from_slice(PlSmallStr::EMPTY, &[$third_val]); + a.append(&a_b).unwrap(); // normal iterator let mut it = a.into_no_null_iter(); @@ -946,7 +982,10 @@ mod test { } impl_test_iter_skip!(utf8_iter_single_chunk_skip, 8, Some("0"), Some("9"), { - StringChunked::from_slice("test", &generate_utf8_vec(SKIP_ITERATOR_SIZE)) + StringChunked::from_slice( + PlSmallStr::from_static("test"), + &generate_utf8_vec(SKIP_ITERATOR_SIZE), + ) }); impl_test_iter_skip!( @@ -954,20 +993,37 @@ mod test { 8, Some("0"), None, - { StringChunked::new("test", &generate_opt_utf8_vec(SKIP_ITERATOR_SIZE)) } + { + StringChunked::new( + PlSmallStr::from_static("test"), + &generate_opt_utf8_vec(SKIP_ITERATOR_SIZE), + ) + } ); impl_test_iter_skip!(utf8_iter_many_chunk_skip, 18, Some("0"), Some("9"), { - let mut a = StringChunked::from_slice("test", &generate_utf8_vec(SKIP_ITERATOR_SIZE)); - let a_b = StringChunked::from_slice("test", &generate_utf8_vec(SKIP_ITERATOR_SIZE)); - a.append(&a_b); + let mut a = StringChunked::from_slice( + PlSmallStr::from_static("test"), + &generate_utf8_vec(SKIP_ITERATOR_SIZE), + ); + let a_b = StringChunked::from_slice( + PlSmallStr::from_static("test"), + &generate_utf8_vec(SKIP_ITERATOR_SIZE), + ); + a.append(&a_b).unwrap(); a }); impl_test_iter_skip!(utf8_iter_many_chunk_null_check_skip, 18, Some("0"), None, { - let mut a = StringChunked::new("test", &generate_opt_utf8_vec(SKIP_ITERATOR_SIZE)); - let a_b = StringChunked::new("test", &generate_opt_utf8_vec(SKIP_ITERATOR_SIZE)); - a.append(&a_b); + let mut a = StringChunked::new( + PlSmallStr::from_static("test"), + &generate_opt_utf8_vec(SKIP_ITERATOR_SIZE), + ); + let a_b = StringChunked::new( + PlSmallStr::from_static("test"), + &generate_opt_utf8_vec(SKIP_ITERATOR_SIZE), + ); + a.append(&a_b).unwrap(); a }); @@ -987,24 +1043,42 @@ mod test { } impl_test_iter_skip!(bool_iter_single_chunk_skip, 8, Some(true), Some(false), { - BooleanChunked::from_slice("test", &generate_boolean_vec(SKIP_ITERATOR_SIZE)) + BooleanChunked::from_slice( + PlSmallStr::from_static("test"), + &generate_boolean_vec(SKIP_ITERATOR_SIZE), + ) }); impl_test_iter_skip!(bool_iter_single_chunk_null_check_skip, 8, None, None, { - BooleanChunked::new("test", &generate_opt_boolean_vec(SKIP_ITERATOR_SIZE)) + BooleanChunked::new( + PlSmallStr::from_static("test"), + &generate_opt_boolean_vec(SKIP_ITERATOR_SIZE), + ) }); impl_test_iter_skip!(bool_iter_many_chunk_skip, 18, Some(true), Some(false), { - let mut a = BooleanChunked::from_slice("test", &generate_boolean_vec(SKIP_ITERATOR_SIZE)); - let a_b = BooleanChunked::from_slice("test", &generate_boolean_vec(SKIP_ITERATOR_SIZE)); - a.append(&a_b); + let mut a = BooleanChunked::from_slice( + PlSmallStr::from_static("test"), + &generate_boolean_vec(SKIP_ITERATOR_SIZE), + ); + let a_b = BooleanChunked::from_slice( + PlSmallStr::from_static("test"), + &generate_boolean_vec(SKIP_ITERATOR_SIZE), + ); + a.append(&a_b).unwrap(); a }); impl_test_iter_skip!(bool_iter_many_chunk_null_check_skip, 18, None, None, { - let mut a = BooleanChunked::new("test", &generate_opt_boolean_vec(SKIP_ITERATOR_SIZE)); - let a_b = BooleanChunked::new("test", &generate_opt_boolean_vec(SKIP_ITERATOR_SIZE)); - a.append(&a_b); + let mut a = BooleanChunked::new( + PlSmallStr::from_static("test"), + &generate_opt_boolean_vec(SKIP_ITERATOR_SIZE), + ); + let a_b = BooleanChunked::new( + PlSmallStr::from_static("test"), + &generate_opt_boolean_vec(SKIP_ITERATOR_SIZE), + ); + a.append(&a_b).unwrap(); a }); } diff --git 
a/crates/polars-core/src/chunked_array/iterator/par/list.rs b/crates/polars-core/src/chunked_array/iterator/par/list.rs index 63eb77753215..05e02e0ccf01 100644 --- a/crates/polars-core/src/chunked_array/iterator/par/list.rs +++ b/crates/polars-core/src/chunked_array/iterator/par/list.rs @@ -4,8 +4,9 @@ use crate::prelude::*; unsafe fn idx_to_array(idx: usize, arr: &ListArray, dtype: &DataType) -> Option { if arr.is_valid(idx) { - Some(arr.value_unchecked(idx)) - .map(|arr: ArrayRef| Series::from_chunks_and_dtype_unchecked("", vec![arr], dtype)) + Some(arr.value_unchecked(idx)).map(|arr: ArrayRef| { + Series::from_chunks_and_dtype_unchecked(PlSmallStr::EMPTY, vec![arr], dtype) + }) } else { None } diff --git a/crates/polars-core/src/chunked_array/list/iterator.rs b/crates/polars-core/src/chunked_array/list/iterator.rs index 19d6fc90c952..625eff6c9559 100644 --- a/crates/polars-core/src/chunked_array/list/iterator.rs +++ b/crates/polars-core/src/chunked_array/list/iterator.rs @@ -51,7 +51,7 @@ impl<'a, I: Iterator>> Iterator for AmortizedListIter<'a // dtype is known unsafe { let s = Series::from_chunks_and_dtype_unchecked( - "", + PlSmallStr::EMPTY, vec![array_ref], &self.inner_dtype.to_physical(), ) @@ -69,7 +69,7 @@ impl<'a, I: Iterator>> Iterator for AmortizedListIter<'a { let (s, ptr) = unsafe { unstable_series_container_and_ptr( - self.series_container.name(), + self.series_container.name().clone(), array_ref, self.series_container.dtype(), ) @@ -123,13 +123,13 @@ impl ListChunked { /// If the returned `AmortSeries` is cloned, the local copy will be replaced and a new container /// will be set. pub fn amortized_iter(&self) -> AmortizedListIter> + '_> { - self.amortized_iter_with_name("") + self.amortized_iter_with_name(PlSmallStr::EMPTY) } /// See `amortized_iter`. pub fn amortized_iter_with_name( &self, - name: &str, + name: PlSmallStr, ) -> AmortizedListIter> + '_> { // we create the series container from the inner array // so that the container has the proper dtype. @@ -172,7 +172,7 @@ impl ListChunked { V::Array: ArrayFromIter>, { // TODO! make an amortized iter that does not flatten - self.amortized_iter().map(f).collect_ca(self.name()) + self.amortized_iter().map(f).collect_ca(self.name().clone()) } pub fn try_apply_amortized_generic(&self, f: F) -> PolarsResult> @@ -182,7 +182,9 @@ impl ListChunked { V::Array: ArrayFromIter>, { // TODO! make an amortized iter that does not flatten - self.amortized_iter().map(f).try_collect_ca(self.name()) + self.amortized_iter() + .map(f) + .try_collect_ca(self.name().clone()) } pub fn for_each_amortized(&self, f: F) @@ -224,7 +226,7 @@ impl ListChunked { .collect_trusted() }; - out.rename(self.name()); + out.rename(self.name().clone()); if fast_explode { out.set_fast_explode(); } @@ -271,7 +273,7 @@ impl ListChunked { .collect_trusted() }; - out.rename(self.name()); + out.rename(self.name().clone()); if fast_explode { out.set_fast_explode(); } @@ -312,7 +314,7 @@ impl ListChunked { .collect::>()? }; - out.rename(self.name()); + out.rename(self.name().clone()); if fast_explode { out.set_fast_explode(); } @@ -343,7 +345,7 @@ impl ListChunked { .collect_trusted() }; - ca.rename(self.name()); + ca.rename(self.name().clone()); if fast_explode { ca.set_fast_explode(); } @@ -375,7 +377,7 @@ impl ListChunked { }) .collect::>()? 
}; - ca.rename(self.name()); + ca.rename(self.name().clone()); if fast_explode { ca.set_fast_explode(); } @@ -390,10 +392,16 @@ mod test { #[test] fn test_iter_list() { - let mut builder = get_list_builder(&DataType::Int32, 10, 10, "").unwrap(); - builder.append_series(&Series::new("", &[1, 2, 3])).unwrap(); - builder.append_series(&Series::new("", &[3, 2, 1])).unwrap(); - builder.append_series(&Series::new("", &[1, 1])).unwrap(); + let mut builder = get_list_builder(&DataType::Int32, 10, 10, PlSmallStr::EMPTY).unwrap(); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[1, 2, 3])) + .unwrap(); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[3, 2, 1])) + .unwrap(); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[1, 1])) + .unwrap(); let ca = builder.finish(); ca.amortized_iter().zip(&ca).for_each(|(s1, s2)| { diff --git a/crates/polars-core/src/chunked_array/list/mod.rs b/crates/polars-core/src/chunked_array/list/mod.rs index 81adf6a48d30..8b730966b1bc 100644 --- a/crates/polars-core/src/chunked_array/list/mod.rs +++ b/crates/polars-core/src/chunked_array/list/mod.rs @@ -41,7 +41,9 @@ impl ListChunked { let chunks: Vec<_> = self.downcast_iter().map(|c| c.values().clone()).collect(); // SAFETY: Data type of arrays matches because they are chunks from the same array. - unsafe { Series::from_chunks_and_dtype_unchecked(self.name(), chunks, self.inner_dtype()) } + unsafe { + Series::from_chunks_and_dtype_unchecked(self.name().clone(), chunks, self.inner_dtype()) + } } /// Returns an iterator over the offsets of this chunked array. @@ -76,7 +78,7 @@ impl ListChunked { // Inner dtype is passed correctly let elements = unsafe { Series::from_chunks_and_dtype_unchecked( - self.name(), + self.name().clone(), vec![arr.values().clone()], ca.inner_dtype(), ) @@ -91,7 +93,7 @@ impl ListChunked { let out = out.rechunk(); let values = out.chunks()[0].clone(); - let inner_dtype = LargeListArray::default_datatype(values.data_type().clone()); + let inner_dtype = LargeListArray::default_datatype(values.dtype().clone()); let arr = LargeListArray::new( inner_dtype, (*arr.offsets()).clone(), @@ -102,7 +104,7 @@ impl ListChunked { // SAFETY: arr's inner dtype is derived from out dtype. 
Ok(unsafe { ListChunked::from_chunks_and_dtype_unchecked( - ca.name(), + ca.name().clone(), vec![Box::new(arr)], DataType::List(Box::new(out.dtype().clone())), ) diff --git a/crates/polars-core/src/chunked_array/logical/categorical/builder.rs b/crates/polars-core/src/chunked_array/logical/categorical/builder.rs index 6342f18f29b0..23739e71d965 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/builder.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/builder.rs @@ -1,7 +1,7 @@ use arrow::array::*; use arrow::legacy::trusted_len::TrustedLenPush; use hashbrown::hash_map::Entry; -use polars_utils::iter::EnumerateIdxTrait; +use polars_utils::itertools::Itertools; use crate::hashing::_HASHMAP_INIT_SIZE; use crate::prelude::*; @@ -13,7 +13,7 @@ struct KeyWrapper(u32); pub struct CategoricalChunkedBuilder { cat_builder: UInt32Vec, - name: String, + name: PlSmallStr, ordering: CategoricalOrdering, categories: MutablePlString, // hashmap utilized by the local builder @@ -21,10 +21,10 @@ pub struct CategoricalChunkedBuilder { } impl CategoricalChunkedBuilder { - pub fn new(name: &str, capacity: usize, ordering: CategoricalOrdering) -> Self { + pub fn new(name: PlSmallStr, capacity: usize, ordering: CategoricalOrdering) -> Self { Self { cat_builder: UInt32Vec::with_capacity(capacity), - name: name.to_string(), + name, ordering, categories: MutablePlString::with_capacity(_HASHMAP_INIT_SIZE), local_mapping: PlHashMap::with_capacity_and_hasher( @@ -166,7 +166,7 @@ impl CategoricalChunkedBuilder { ); let indices = std::mem::take(&mut self.cat_builder).into(); - let indices = UInt32Chunked::with_chunk(&self.name, indices); + let indices = UInt32Chunked::with_chunk(self.name.clone(), indices); // SAFETY: indices are in bounds of new rev_map unsafe { @@ -196,7 +196,7 @@ impl CategoricalChunkedBuilder { // SAFETY: keys and values are in bounds unsafe { CategoricalChunked::from_keys_and_values( - &self.name, + self.name.clone(), &self.cat_builder.into(), &self.categories.into(), self.ordering, @@ -271,7 +271,7 @@ impl CategoricalChunked { } pub(crate) unsafe fn from_keys_and_values_global( - name: &str, + name: PlSmallStr, keys: impl IntoIterator> + Send, capacity: usize, values: &Utf8ViewArray, @@ -317,7 +317,7 @@ impl CategoricalChunked { } pub(crate) unsafe fn from_keys_and_values_local( - name: &str, + name: PlSmallStr, keys: &PrimitiveArray, values: &Utf8ViewArray, ordering: CategoricalOrdering, @@ -333,7 +333,7 @@ impl CategoricalChunked { /// # Safety /// The caller must ensure that index values in the `keys` are in within bounds of the `values` length. 
pub(crate) unsafe fn from_keys_and_values( - name: &str, + name: PlSmallStr, keys: &PrimitiveArray, values: &Utf8ViewArray, ordering: CategoricalOrdering, @@ -372,8 +372,8 @@ impl CategoricalChunked { .map(|opt_s: Option<&str>| opt_s.and_then(|s| map.get(s).copied())) .collect_arr() }); - let mut keys: UInt32Chunked = ChunkedArray::from_chunk_iter(values.name(), iter); - keys.rename(values.name()); + let mut keys: UInt32Chunked = ChunkedArray::from_chunk_iter(values.name().clone(), iter); + keys.rename(values.name().clone()); let rev_map = RevMapping::build_local(categories.clone()); unsafe { Ok(CategoricalChunked::from_cats_and_rev_map_unchecked( @@ -403,7 +403,7 @@ mod test { Some("foo"), Some("bar"), ]; - let ca = StringChunked::new("a", slice); + let ca = StringChunked::new(PlSmallStr::from_static("a"), slice); let out = ca.cast(&DataType::Categorical(None, Default::default()))?; let out = out.categorical().unwrap().clone(); assert_eq!(out.get_rev_map().len(), 2); @@ -422,10 +422,10 @@ mod test { // Check that we don't panic if we append two categorical arrays // build under the same string cache // https://github.com/pola-rs/polars/issues/1115 - let ca1 = StringChunked::new("a", slice) + let ca1 = StringChunked::new(PlSmallStr::from_static("a"), slice) .cast(&DataType::Categorical(None, Default::default()))?; let mut ca1 = ca1.categorical().unwrap().clone(); - let ca2 = StringChunked::new("a", slice) + let ca2 = StringChunked::new(PlSmallStr::from_static("a"), slice) .cast(&DataType::Categorical(None, Default::default()))?; let ca2 = ca2.categorical().unwrap(); ca1.append(ca2).unwrap(); @@ -445,8 +445,16 @@ mod test { // Use 2 builders to check if the global string cache // does not interfere with the index mapping - let builder1 = CategoricalChunkedBuilder::new("foo", 10, Default::default()); - let builder2 = CategoricalChunkedBuilder::new("foo", 10, Default::default()); + let builder1 = CategoricalChunkedBuilder::new( + PlSmallStr::from_static("foo"), + 10, + Default::default(), + ); + let builder2 = CategoricalChunkedBuilder::new( + PlSmallStr::from_static("foo"), + 10, + Default::default(), + ); let s = builder1 .drain_iter_and_finish(vec![None, Some("hello"), Some("vietnam")]) .into_series(); diff --git a/crates/polars-core/src/chunked_array/logical/categorical/merge.rs b/crates/polars-core/src/chunked_array/logical/categorical/merge.rs index 375f8cc3e72f..0e72de7a903f 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/merge.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/merge.rs @@ -240,7 +240,7 @@ pub fn make_list_categoricals_compatible( .zip(cat_physical.chunks()) .for_each(|(arr, new_phys)| { *arr = ListArray::new( - arr.data_type().clone(), + arr.dtype().clone(), arr.offsets().clone(), new_phys.clone(), arr.validity().cloned(), diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs index 32d8a79691da..6f681afbb5b7 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs @@ -8,7 +8,7 @@ pub mod string_cache; use bitflags::bitflags; pub use builder::*; pub use merge::*; -use polars_utils::iter::EnumerateIdxTrait; +use polars_utils::itertools::Itertools; use polars_utils::sync::SyncPtr; pub use revmap::*; @@ -37,7 +37,7 @@ pub struct CategoricalChunked { impl CategoricalChunked { pub(crate) fn field(&self) -> Field { let name = self.physical().name(); - 
Field::new(name, self.dtype().clone()) + Field::new(name.clone(), self.dtype().clone()) } pub fn is_empty(&self) -> bool { @@ -54,7 +54,7 @@ impl CategoricalChunked { self.physical.null_count() } - pub fn name(&self) -> &str { + pub fn name(&self) -> &PlSmallStr { self.physical.name() } @@ -122,7 +122,7 @@ impl CategoricalChunked { // SAFETY: keys and values are in bounds unsafe { Ok(CategoricalChunked::from_keys_and_values_global( - self.name(), + self.name().clone(), self.physical(), self.len(), categories, @@ -337,7 +337,8 @@ impl LogicalType for CategoricalChunked { DataType::String => { let mapping = &**self.get_rev_map(); - let mut builder = StringChunkedBuilder::new(self.physical.name(), self.len()); + let mut builder = + StringChunkedBuilder::new(self.physical.name().clone(), self.len()); let f = |idx: u32| mapping.get(idx); @@ -356,7 +357,10 @@ impl LogicalType for CategoricalChunked { }, DataType::UInt32 => { let ca = unsafe { - UInt32Chunked::from_chunks(self.physical.name(), self.physical.chunks.clone()) + UInt32Chunked::from_chunks( + self.physical.name().clone(), + self.physical.chunks.clone(), + ) }; Ok(ca.into_series()) }, @@ -369,7 +373,7 @@ impl LogicalType for CategoricalChunked { .to_enum(categories, *hash) .set_ordering(*ordering, true) .into_series() - .with_name(self.name())) + .with_name(self.name().clone())) }, DataType::Enum(None, _) => { polars_bail!(ComputeError: "can not cast to enum without categories present") @@ -393,7 +397,7 @@ impl LogicalType for CategoricalChunked { dt if dt.is_numeric() => { // Apply the cast to the categories and then index into the casted series let categories = StringChunked::with_chunk( - self.physical.name(), + self.physical.name().clone(), self.get_rev_map().get_categories().clone(), ); let casted_series = categories.cast_with_options(dtype, options)?; @@ -460,12 +464,12 @@ mod test { Some("foo"), Some("bar"), ]; - let ca = StringChunked::new("a", slice); + let ca = StringChunked::new(PlSmallStr::from_static("a"), slice); let ca = ca.cast(&DataType::Categorical(None, Default::default()))?; let ca = ca.categorical().unwrap(); let arr = ca.to_arrow(CompatLevel::newest(), false); - let s = Series::try_from(("foo", arr))?; + let s = Series::try_from((PlSmallStr::from_static("foo"), arr))?; assert!(matches!(s.dtype(), &DataType::Categorical(_, _))); assert_eq!(s.null_count(), 1); assert_eq!(s.len(), 6); @@ -479,10 +483,10 @@ mod test { disable_string_cache(); enable_string_cache(); - let mut s1 = Series::new("1", vec!["a", "b", "c"]) + let mut s1 = Series::new(PlSmallStr::from_static("1"), vec!["a", "b", "c"]) .cast(&DataType::Categorical(None, Default::default())) .unwrap(); - let s2 = Series::new("2", vec!["a", "x", "y"]) + let s2 = Series::new(PlSmallStr::from_static("2"), vec!["a", "x", "y"]) .cast(&DataType::Categorical(None, Default::default())) .unwrap(); let appended = s1.append(&s2).unwrap(); @@ -495,13 +499,13 @@ mod test { #[test] fn test_fast_unique() { let _lock = SINGLE_LOCK.lock(); - let s = Series::new("1", vec!["a", "b", "c"]) + let s = Series::new(PlSmallStr::from_static("1"), vec!["a", "b", "c"]) .cast(&DataType::Categorical(None, Default::default())) .unwrap(); assert_eq!(s.n_unique().unwrap(), 3); // Make sure that it does not take the fast path after take/slice. 
- let out = s.take(&IdxCa::new("", [1, 2])).unwrap(); + let out = s.take(&IdxCa::new(PlSmallStr::EMPTY, [1, 2])).unwrap(); assert_eq!(out.n_unique().unwrap(), 2); let out = s.slice(1, 2); assert_eq!(out.n_unique().unwrap(), 2); @@ -513,12 +517,15 @@ mod test { disable_string_cache(); // tests several things that may lose the dtype information - let s = Series::new("a", vec!["a", "b", "c"]) + let s = Series::new(PlSmallStr::from_static("a"), vec!["a", "b", "c"]) .cast(&DataType::Categorical(None, Default::default()))?; assert_eq!( s.field().into_owned(), - Field::new("a", DataType::Categorical(None, Default::default())) + Field::new( + PlSmallStr::from_static("a"), + DataType::Categorical(None, Default::default()) + ) ); assert!(matches!( s.get(0)?, diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/append.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/append.rs index 3404a22bfdff..884a72916628 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/append.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/ops/append.rs @@ -8,7 +8,7 @@ struct CategoricalAppend; impl CategoricalMergeOperation for CategoricalAppend { fn finish(self, lhs: &UInt32Chunked, rhs: &UInt32Chunked) -> PolarsResult { let mut lhs_mut = lhs.clone(); - lhs_mut.append(rhs); + lhs_mut.append(rhs)?; Ok(lhs_mut) } } diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/full.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/full.rs index 959717155ce3..ed53722d163e 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/full.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/ops/full.rs @@ -2,7 +2,7 @@ use super::*; impl CategoricalChunked { pub fn full_null( - name: &str, + name: PlSmallStr, is_enum: bool, length: usize, ordering: CategoricalOrdering, diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs index 21a4bfb96a6a..a0f4a4ef90db 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs @@ -5,12 +5,14 @@ impl CategoricalChunked { let cat_map = self.get_rev_map(); if self._can_fast_unique() { let ca = match &**cat_map { - RevMapping::Local(a, _) => { - UInt32Chunked::from_iter_values(self.physical().name(), 0..(a.len() as u32)) - }, - RevMapping::Global(map, _, _) => { - UInt32Chunked::from_iter_values(self.physical().name(), map.keys().copied()) - }, + RevMapping::Local(a, _) => UInt32Chunked::from_iter_values( + self.physical().name().clone(), + 0..(a.len() as u32), + ), + RevMapping::Global(map, _, _) => UInt32Chunked::from_iter_values( + self.physical().name().clone(), + map.keys().copied(), + ), }; // SAFETY: // we only removed some indexes so we are still in bounds @@ -63,7 +65,7 @@ impl CategoricalChunked { *values.physical_mut() = physical_values; let mut counts = groups.group_count(); - counts.rename("counts"); + counts.rename(PlSmallStr::from_static("counts")); let cols = vec![values.into_series(), counts.into_series()]; let df = unsafe { DataFrame::new_no_checks(cols) }; df.sort( diff --git a/crates/polars-core/src/chunked_array/logical/categorical/revmap.rs b/crates/polars-core/src/chunked_array/logical/categorical/revmap.rs index 7f8a97e9406c..0458fa2f36be 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/revmap.rs +++ 
b/crates/polars-core/src/chunked_array/logical/categorical/revmap.rs @@ -1,8 +1,8 @@ use std::fmt::{Debug, Formatter}; use std::hash::{BuildHasher, Hash, Hasher}; -use ahash::RandomState; use arrow::array::*; +use polars_utils::aliases::PlRandomState; #[cfg(any(feature = "serde-lazy", feature = "serde"))] use serde::{Deserialize, Serialize}; @@ -76,12 +76,12 @@ impl RevMapping { fn build_hash(categories: &Utf8ViewArray) -> u128 { // TODO! we must also validate the cases of duplicates! - let mut hb = RandomState::with_seed(0).build_hasher(); + let mut hb = PlRandomState::with_seed(0).build_hasher(); categories.values_iter().for_each(|val| { val.hash(&mut hb); }); let hash = hb.finish(); - (hash as u128) << 64 | (categories.total_bytes_len() as u128) + (hash as u128) << 64 | (categories.total_buffer_len() as u128) } pub fn build_local(categories: Utf8ViewArray) -> Self { diff --git a/crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs b/crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs index 91cb2140443e..a0bd2687af63 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs @@ -2,10 +2,10 @@ use std::hash::{Hash, Hasher}; use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; use std::sync::{Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard}; -use ahash::RandomState; use hashbrown::hash_map::RawEntryMut; use once_cell::sync::Lazy; -use smartstring::{LazyCompact, SmartString}; +use polars_utils::aliases::PlRandomState; +use polars_utils::pl_str::PlSmallStr; use crate::datatypes::{InitHashMaps2, PlIdHashMap}; use crate::hashing::_HASHMAP_INIT_SIZE; @@ -133,7 +133,7 @@ impl Hash for Key { pub(crate) struct SCacheInner { map: PlIdHashMap, pub(crate) uuid: u32, - payloads: Vec, + payloads: Vec, } impl SCacheInner { @@ -149,8 +149,8 @@ impl SCacheInner { #[inline] pub(crate) fn insert_from_hash(&mut self, h: u64, s: &str) -> u32 { let mut global_idx = self.payloads.len() as u32; - // Note that we don't create the StrHashGlobal to search the key in the hashmap - // as StrHashGlobal may allocate a string + // Note that we don't create the PlSmallStr to search the key in the hashmap + // as PlSmallStr may allocate a string let entry = self.map.raw_entry_mut().from_hash(h, |key| { (key.hash == h) && { let pos = key.idx as usize; @@ -169,7 +169,7 @@ impl SCacheInner { entry.insert_hashed_nocheck(h, key, ()); // only just now we allocate the string - self.payloads.push(s.into()); + self.payloads.push(PlSmallStr::from_str(s)); }, } global_idx @@ -178,7 +178,6 @@ impl SCacheInner { #[inline] pub(crate) fn get_cat(&self, s: &str) -> Option { let h = StringCache::get_hash_builder().hash_one(s); - // as StrHashGlobal may allocate a string self.map .raw_entry() .from_hash(h, |key| { @@ -219,8 +218,8 @@ impl StringCache { /// The global `StringCache` will always use a predictable seed. This allows local builders to mimic /// the hashes in case of contention. 
#[inline] - pub(crate) fn get_hash_builder() -> RandomState { - RandomState::with_seed(0) + pub(crate) fn get_hash_builder() -> PlRandomState { + PlRandomState::with_seed(0) } /// Lock the string cache @@ -254,5 +253,3 @@ impl StringCache { } pub(crate) static STRING_CACHE: Lazy = Lazy::new(Default::default); - -type StrHashGlobal = SmartString; diff --git a/crates/polars-core/src/chunked_array/logical/decimal.rs b/crates/polars-core/src/chunked_array/logical/decimal.rs index b8bd978fda0e..f723bc3b7e70 100644 --- a/crates/polars-core/src/chunked_array/logical/decimal.rs +++ b/crates/polars-core/src/chunked_array/logical/decimal.rs @@ -11,7 +11,7 @@ impl Int128Chunked { // physical i128 type doesn't exist // so we update the decimal dtype for arr in self.chunks.iter_mut() { - let mut default = PrimitiveArray::new_empty(arr.data_type().clone()); + let mut default = PrimitiveArray::new_empty(arr.dtype().clone()); let arr = arr .as_any_mut() .downcast_mut::>() @@ -104,7 +104,7 @@ impl LogicalType for DecimalChunked { let chunks = cast_chunks(&self.chunks, dtype, cast_options)?; unsafe { Ok(Series::from_chunks_and_dtype_unchecked( - self.name(), + self.name().clone(), chunks, dtype, )) @@ -134,7 +134,8 @@ impl DecimalChunked { let dtype = DataType::Decimal(None, Some(scale)); let chunks = cast_chunks(&self.chunks, &dtype, CastOptions::NonStrict)?; - let mut dt = Self::new_logical(unsafe { Int128Chunked::from_chunks(self.name(), chunks) }); + let mut dt = + Self::new_logical(unsafe { Int128Chunked::from_chunks(self.name().clone(), chunks) }); dt.2 = Some(dtype); Ok(Cow::Owned(dt)) } diff --git a/crates/polars-core/src/chunked_array/logical/mod.rs b/crates/polars-core/src/chunked_array/logical/mod.rs index 8ed6dc4dae35..d4e6d4eb84aa 100644 --- a/crates/polars-core/src/chunked_array/logical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/mod.rs @@ -97,6 +97,6 @@ where } pub fn field(&self) -> Field { let name = self.0.ref_field().name(); - Field::new(name, LogicalType::dtype(self).clone()) + Field::new(name.clone(), LogicalType::dtype(self).clone()) } } diff --git a/crates/polars-core/src/chunked_array/mod.rs b/crates/polars-core/src/chunked_array/mod.rs index 71625393ec51..c59b520bf8e8 100644 --- a/crates/polars-core/src/chunked_array/mod.rs +++ b/crates/polars-core/src/chunked_array/mod.rs @@ -512,7 +512,7 @@ impl ChunkedArray { // SAFETY: we keep the correct dtype let mut ca = unsafe { self.copy_with_chunks(vec![new_empty_array( - self.chunks.first().unwrap().data_type().clone(), + self.chunks.first().unwrap().dtype().clone(), )]) }; @@ -599,15 +599,15 @@ impl ChunkedArray { /// Get data type of [`ChunkedArray`]. pub fn dtype(&self) -> &DataType { - self.field.data_type() + self.field.dtype() } pub(crate) unsafe fn set_dtype(&mut self, dtype: DataType) { - self.field = Arc::new(Field::new(self.name(), dtype)) + self.field = Arc::new(Field::new(self.name().clone(), dtype)) } /// Name of the [`ChunkedArray`]. - pub fn name(&self) -> &str { + pub fn name(&self) -> &PlSmallStr { self.field.name() } @@ -617,12 +617,12 @@ impl ChunkedArray { } /// Rename this [`ChunkedArray`]. - pub fn rename(&mut self, name: &str) { - self.field = Arc::new(Field::new(name, self.field.data_type().clone())) + pub fn rename(&mut self, name: PlSmallStr) { + self.field = Arc::new(Field::new(name, self.field.dtype().clone())) } /// Return this [`ChunkedArray`] with a new name. 
- pub fn with_name(mut self, name: &str) -> Self { + pub fn with_name(mut self, name: PlSmallStr) -> Self { self.rename(name); self } @@ -690,6 +690,14 @@ where } } + #[inline] + pub fn first(&self) -> Option> { + unsafe { + let arr = self.downcast_get_unchecked(0); + arr.get_unchecked(0) + } + } + #[inline] pub fn last(&self) -> Option> { unsafe { @@ -704,7 +712,7 @@ impl ListChunked { pub fn get_as_series(&self, idx: usize) -> Option { unsafe { Some(Series::from_chunks_and_dtype_unchecked( - self.name(), + self.name().clone(), vec![self.get(idx)?], &self.inner_dtype().to_physical(), )) @@ -718,7 +726,7 @@ impl ArrayChunked { pub fn get_as_series(&self, idx: usize) -> Option { unsafe { Some(Series::from_chunks_and_dtype_unchecked( - self.name(), + self.name().clone(), vec![self.get(idx)?], &self.inner_dtype().to_physical(), )) @@ -754,7 +762,9 @@ where .collect(); // SAFETY: We just slice the original chunks, their type will not change. - unsafe { Self::from_chunks_and_dtype(self.name(), chunks, self.dtype().clone()) } + unsafe { + Self::from_chunks_and_dtype(self.name().clone(), chunks, self.dtype().clone()) + } }; if self.chunks.len() != 1 { @@ -948,9 +958,12 @@ pub(crate) fn to_array( impl Default for ChunkedArray { fn default() -> Self { + let dtype = T::get_dtype(); + let arrow_dtype = dtype.to_physical().to_arrow(CompatLevel::newest()); ChunkedArray { - field: Arc::new(Field::new("default", DataType::Null)), - chunks: Default::default(), + field: Arc::new(Field::new(PlSmallStr::EMPTY, dtype)), + // Invariant: always has 1 chunk. + chunks: vec![new_empty_array(arrow_dtype)], md: Arc::new(IMMetadata::default()), length: 0, null_count: 0, @@ -963,19 +976,19 @@ pub(crate) mod test { use crate::prelude::*; pub(crate) fn get_chunked_array() -> Int32Chunked { - ChunkedArray::new("a", &[1, 2, 3]) + ChunkedArray::new(PlSmallStr::from_static("a"), &[1, 2, 3]) } #[test] fn test_sort() { - let a = Int32Chunked::new("a", &[1, 9, 3, 2]); + let a = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 9, 3, 2]); let b = a .sort(false) .into_iter() .map(|opt| opt.unwrap()) .collect::>(); assert_eq!(b, [1, 2, 3, 9]); - let a = StringChunked::new("a", &["b", "a", "c"]); + let a = StringChunked::new(PlSmallStr::from_static("a"), &["b", "a", "c"]); let a = a.sort(false); let b = a.into_iter().collect::>(); assert_eq!(b, [Some("a"), Some("b"), Some("c")]); @@ -984,8 +997,8 @@ pub(crate) mod test { #[test] fn arithmetic() { - let a = &Int32Chunked::new("a", &[1, 100, 6, 40]); - let b = &Int32Chunked::new("b", &[-1, 2, 3, 4]); + let a = &Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 100, 6, 40]); + let b = &Int32Chunked::new(PlSmallStr::from_static("b"), &[-1, 2, 3, 4]); // Not really asserting anything here but still making sure the code is exercised // This (and more) is properly tested from the integration test suite and Python bindings. 
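A minimal usage sketch of the PlSmallStr-based naming API that these hunks introduce (not part of the patch; it assumes `PlSmallStr` is re-exported through the polars_core prelude, as in the updated tests):

    use polars_core::prelude::*;

    // Column names are now owned `PlSmallStr` values instead of `&str`.
    let mut ca = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3]);

    // `rename` and `with_name` take the name by value; `&str` still converts via `.into()`.
    ca.rename(PlSmallStr::from_static("b"));
    let ca = ca.with_name("c".into());

    // `name()` now returns `&PlSmallStr`, so call sites clone it explicitly.
    let name: PlSmallStr = ca.name().clone();
    assert_eq!(name, PlSmallStr::from_static("c"));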
@@ -1014,7 +1027,10 @@ pub(crate) mod test { fn filter() { let a = get_chunked_array(); let b = a - .filter(&BooleanChunked::new("filter", &[true, false, false])) + .filter(&BooleanChunked::new( + PlSmallStr::from_static("filter"), + &[true, false, false], + )) .unwrap(); assert_eq!(b.len(), 1); assert_eq!(b.into_iter().next(), Some(Some(1))); @@ -1022,7 +1038,7 @@ pub(crate) mod test { #[test] fn aggregates() { - let a = &Int32Chunked::new("a", &[1, 100, 10, 9]); + let a = &Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 100, 10, 9]); assert_eq!(a.max(), Some(100)); assert_eq!(a.min(), Some(1)); assert_eq!(a.sum(), Some(120)) @@ -1051,9 +1067,9 @@ pub(crate) mod test { #[test] fn slice() { - let mut first = UInt32Chunked::new("first", &[0, 1, 2]); - let second = UInt32Chunked::new("second", &[3, 4, 5]); - first.append(&second); + let mut first = UInt32Chunked::new(PlSmallStr::from_static("first"), &[0, 1, 2]); + let second = UInt32Chunked::new(PlSmallStr::from_static("second"), &[3, 4, 5]); + first.append(&second).unwrap(); assert_slice_equal(&first.slice(0, 3), &[0, 1, 2]); assert_slice_equal(&first.slice(0, 4), &[0, 1, 2, 3]); assert_slice_equal(&first.slice(1, 4), &[1, 2, 3, 4]); @@ -1070,7 +1086,7 @@ pub(crate) mod test { #[test] fn sorting() { - let s = UInt32Chunked::new("", &[9, 2, 4]); + let s = UInt32Chunked::new(PlSmallStr::EMPTY, &[9, 2, 4]); let sorted = s.sort(false); assert_slice_equal(&sorted, &[2, 4, 9]); let sorted = s.sort(true); @@ -1097,19 +1113,19 @@ pub(crate) mod test { #[test] fn reverse() { - let s = UInt32Chunked::new("", &[1, 2, 3]); + let s = UInt32Chunked::new(PlSmallStr::EMPTY, &[1, 2, 3]); // path with continuous slice assert_slice_equal(&s.reverse(), &[3, 2, 1]); // path with options - let s = UInt32Chunked::new("", &[Some(1), None, Some(3)]); + let s = UInt32Chunked::new(PlSmallStr::EMPTY, &[Some(1), None, Some(3)]); assert_eq!(Vec::from(&s.reverse()), &[Some(3), None, Some(1)]); - let s = BooleanChunked::new("", &[true, false]); + let s = BooleanChunked::new(PlSmallStr::EMPTY, &[true, false]); assert_eq!(Vec::from(&s.reverse()), &[Some(false), Some(true)]); - let s = StringChunked::new("", &["a", "b", "c"]); + let s = StringChunked::new(PlSmallStr::EMPTY, &["a", "b", "c"]); assert_eq!(Vec::from(&s.reverse()), &[Some("c"), Some("b"), Some("a")]); - let s = StringChunked::new("", &[Some("a"), None, Some("c")]); + let s = StringChunked::new(PlSmallStr::EMPTY, &[Some("a"), None, Some("c")]); assert_eq!(Vec::from(&s.reverse()), &[Some("c"), None, Some("a")]); } @@ -1119,7 +1135,10 @@ pub(crate) mod test { use crate::{disable_string_cache, SINGLE_LOCK}; let _lock = SINGLE_LOCK.lock(); disable_string_cache(); - let ca = StringChunked::new("", &[Some("foo"), None, Some("bar"), Some("ham")]); + let ca = StringChunked::new( + PlSmallStr::EMPTY, + &[Some("foo"), None, Some("bar"), Some("ham")], + ); let ca = ca .cast(&DataType::Categorical(None, Default::default())) .unwrap(); @@ -1131,7 +1150,7 @@ pub(crate) mod test { #[test] #[ignore] fn test_shrink_to_fit() { - let mut builder = StringChunkedBuilder::new("foo", 2048); + let mut builder = StringChunkedBuilder::new(PlSmallStr::from_static("foo"), 2048); builder.append_value("foo"); let mut arr = builder.finish(); let before = arr diff --git a/crates/polars-core/src/chunked_array/ndarray.rs b/crates/polars-core/src/chunked_array/ndarray.rs index 9bff6d03f411..079061e31478 100644 --- a/crates/polars-core/src/chunked_array/ndarray.rs +++ b/crates/polars-core/src/chunked_array/ndarray.rs @@ -83,8 +83,8 @@ impl 
DataFrame { /// /// ```rust /// use polars_core::prelude::*; - /// let a = UInt32Chunked::new("a", &[1, 2, 3]).into_series(); - /// let b = Float64Chunked::new("b", &[10., 8., 6.]).into_series(); + /// let a = UInt32Chunked::new("a".into(), &[1, 2, 3]).into_series(); + /// let b = Float64Chunked::new("b".into(), &[10., 8., 6.]).into_series(); /// /// let df = DataFrame::new(vec![a, b]).unwrap(); /// let ndarray = df.to_ndarray::(IndexOrder::Fortran).unwrap(); @@ -186,12 +186,16 @@ mod test { #[test] fn test_ndarray_from_ca() -> PolarsResult<()> { - let ca = Float64Chunked::new("", &[1.0, 2.0, 3.0]); + let ca = Float64Chunked::new(PlSmallStr::EMPTY, &[1.0, 2.0, 3.0]); let ndarr = ca.to_ndarray()?; assert_eq!(ndarr, ArrayView1::from(&[1.0, 2.0, 3.0])); - let mut builder = - ListPrimitiveChunkedBuilder::::new("", 10, 10, DataType::Float64); + let mut builder = ListPrimitiveChunkedBuilder::::new( + PlSmallStr::EMPTY, + 10, + 10, + DataType::Float64, + ); builder.append_opt_slice(Some(&[1.0, 2.0, 3.0])); builder.append_opt_slice(Some(&[2.0, 4.0, 5.0])); builder.append_opt_slice(Some(&[6.0, 7.0, 8.0])); @@ -202,8 +206,12 @@ mod test { assert_eq!(ndarr, expected); // test list array that is not square - let mut builder = - ListPrimitiveChunkedBuilder::::new("", 10, 10, DataType::Float64); + let mut builder = ListPrimitiveChunkedBuilder::::new( + PlSmallStr::EMPTY, + 10, + 10, + DataType::Float64, + ); builder.append_opt_slice(Some(&[1.0, 2.0, 3.0])); builder.append_opt_slice(Some(&[2.0])); builder.append_opt_slice(Some(&[6.0, 7.0, 8.0])); diff --git a/crates/polars-core/src/chunked_array/object/builder.rs b/crates/polars-core/src/chunked_array/object/builder.rs index d54fb8c7dea6..01524c018ec2 100644 --- a/crates/polars-core/src/chunked_array/object/builder.rs +++ b/crates/polars-core/src/chunked_array/object/builder.rs @@ -12,7 +12,7 @@ impl ObjectChunkedBuilder where T: PolarsObject, { - pub fn new(name: &str, capacity: usize) -> Self { + pub fn new(name: PlSmallStr, capacity: usize) -> Self { ObjectChunkedBuilder { field: Field::new(name, DataType::Object(T::type_name(), None)), values: Vec::with_capacity(capacity), @@ -78,7 +78,7 @@ where /// Initialize a polars Object data type. The type has got information needed to /// construct new objects. pub(crate) fn get_object_type() -> DataType { - let object_builder = Box::new(|name: &str, capacity: usize| { + let object_builder = Box::new(|name: PlSmallStr, capacity: usize| { Box::new(ObjectChunkedBuilder::::new(name, capacity)) as Box }); @@ -94,7 +94,7 @@ where T: PolarsObject, { fn default() -> Self { - ObjectChunkedBuilder::new("", 0) + ObjectChunkedBuilder::new(PlSmallStr::EMPTY, 0) } } @@ -102,11 +102,11 @@ impl NewChunkedArray, T> for ObjectChunked where T: PolarsObject, { - fn from_slice(name: &str, v: &[T]) -> Self { + fn from_slice(name: PlSmallStr, v: &[T]) -> Self { Self::from_iter_values(name, v.iter().cloned()) } - fn from_slice_options(name: &str, opt_v: &[Option]) -> Self { + fn from_slice_options(name: PlSmallStr, opt_v: &[Option]) -> Self { let mut builder = ObjectChunkedBuilder::::new(name, opt_v.len()); opt_v .iter() @@ -115,14 +115,17 @@ where builder.finish() } - fn from_iter_options(name: &str, it: impl Iterator>) -> ObjectChunked { + fn from_iter_options( + name: PlSmallStr, + it: impl Iterator>, + ) -> ObjectChunked { let mut builder = ObjectChunkedBuilder::new(name, get_iter_capacity(&it)); it.for_each(|opt| builder.append_option(opt)); builder.finish() } /// Create a new ChunkedArray from an iterator. 
- fn from_iter_values(name: &str, it: impl Iterator) -> ObjectChunked { + fn from_iter_values(name: PlSmallStr, it: impl Iterator) -> ObjectChunked { let mut builder = ObjectChunkedBuilder::new(name, get_iter_capacity(&it)); it.for_each(|v| builder.append_value(v)); builder.finish() @@ -133,7 +136,7 @@ impl ObjectChunked where T: PolarsObject, { - pub fn new_from_vec(name: &str, v: Vec) -> Self { + pub fn new_from_vec(name: PlSmallStr, v: Vec) -> Self { let field = Arc::new(Field::new(name, DataType::Object(T::type_name(), None))); let len = v.len(); let arr = Box::new(ObjectArray { @@ -146,7 +149,7 @@ where unsafe { ObjectChunked::new_with_dims(field, vec![arr], len as IdxSize, 0) } } - pub fn new_from_vec_and_validity(name: &str, v: Vec, validity: Bitmap) -> Self { + pub fn new_from_vec_and_validity(name: PlSmallStr, v: Vec, validity: Bitmap) -> Self { let field = Arc::new(Field::new(name, DataType::Object(T::type_name(), None))); let len = v.len(); let null_count = validity.unset_bits(); @@ -162,7 +165,7 @@ where } } - pub fn new_empty(name: &str) -> Self { + pub fn new_empty(name: PlSmallStr) -> Self { Self::new_from_vec(name, vec![]) } } diff --git a/crates/polars-core/src/chunked_array/object/extension/drop.rs b/crates/polars-core/src/chunked_array/object/extension/drop.rs index a4374cd68d78..3b3e16deff2e 100644 --- a/crates/polars-core/src/chunked_array/object/extension/drop.rs +++ b/crates/polars-core/src/chunked_array/object/extension/drop.rs @@ -18,8 +18,8 @@ pub(crate) unsafe fn drop_list(ca: &ListChunked) { // if empty the memory is leaked somewhere assert!(!ca.chunks.is_empty()); for lst_arr in &ca.chunks { - if let ArrowDataType::LargeList(fld) = lst_arr.data_type() { - let dtype = fld.data_type(); + if let ArrowDataType::LargeList(fld) = lst_arr.dtype() { + let dtype = fld.dtype(); assert!(matches!(dtype, ArrowDataType::Extension(_, _, _))); @@ -39,10 +39,9 @@ pub(crate) unsafe fn drop_object_array(values: &dyn Array) { .downcast_ref::() .unwrap(); - // if the buf is not shared with anyone but us - // we can deallocate + // If the buf is not shared with anyone but us we can deallocate. 
let buf = arr.values(); - if buf.shared_count_strong() == 1 { + if buf.shared_count_strong() == 1 && !buf.is_empty() { PolarsExtension::new(arr.clone()); }; } diff --git a/crates/polars-core/src/chunked_array/object/extension/list.rs b/crates/polars-core/src/chunked_array/object/extension/list.rs index fb4ea6d73a2c..1918039d647e 100644 --- a/crates/polars-core/src/chunked_array/object/extension/list.rs +++ b/crates/polars-core/src/chunked_array/object/extension/list.rs @@ -6,7 +6,7 @@ use crate::prelude::*; impl ObjectChunked { pub(crate) fn get_list_builder( - name: &str, + name: PlSmallStr, values_capacity: usize, list_capacity: usize, ) -> Box { @@ -25,7 +25,7 @@ struct ExtensionListBuilder { } impl ExtensionListBuilder { - pub(crate) fn new(name: &str, values_capacity: usize, list_capacity: usize) -> Self { + pub(crate) fn new(name: PlSmallStr, values_capacity: usize, list_capacity: usize) -> Self { let mut offsets = Vec::with_capacity(list_capacity + 1); offsets.push(0); Self { @@ -69,18 +69,18 @@ impl ListBuilderTrait for ExtensionListBuilder { let mut pe = create_extension(obj_arr.into_iter_cloned()); unsafe { pe.set_to_series_fn::() }; let extension_array = Box::new(pe.take_and_forget()) as ArrayRef; - let extension_dtype = extension_array.data_type(); + let extension_dtype = extension_array.dtype(); - let data_type = ListArray::::default_datatype(extension_dtype.clone()); + let dtype = ListArray::::default_datatype(extension_dtype.clone()); let arr = ListArray::::new( - data_type, + dtype, // SAFETY: offsets are monotonically increasing. unsafe { Offsets::new_unchecked(offsets).into() }, extension_array, None, ); - let mut listarr = ListChunked::with_chunk(ca.name(), arr); + let mut listarr = ListChunked::with_chunk(ca.name().clone(), arr); if self.fast_explode { listarr.set_fast_explode() } diff --git a/crates/polars-core/src/chunked_array/object/extension/mod.rs b/crates/polars-core/src/chunked_array/object/extension/mod.rs index 51916297de91..5a049da4a01f 100644 --- a/crates/polars-core/src/chunked_array/object/extension/mod.rs +++ b/crates/polars-core/src/chunked_array/object/extension/mod.rs @@ -9,6 +9,7 @@ use arrow::array::FixedSizeBinaryArray; use arrow::bitmap::MutableBitmap; use arrow::buffer::Buffer; use polars_extension::PolarsExtension; +use polars_utils::format_pl_smallstr; use crate::prelude::*; use crate::PROCESS_ID; @@ -39,9 +40,9 @@ unsafe fn create_drop(mut ptr: *const u8, n_t_vals: usize) -> Box>, - // A function on the heap that take a `array: FixedSizeBinary` and a `name: &str` + // A function on the heap that take a `array: FixedSizeBinary` and a `name: PlSmallStr` // and returns a `Series` of `ObjectChunked` - pub(crate) to_series_fn: Option Series>>, + pub(crate) to_series_fn: Option Series>>, } impl Drop for ExtensionSentinel { @@ -120,11 +121,14 @@ pub(crate) fn create_extension> + TrustedLen, T: Si let et_ptr = &*et as *const ExtensionSentinel; std::mem::forget(et); - let metadata = format!("{};{}", *PROCESS_ID, et_ptr as usize); + let metadata = format_pl_smallstr!("{};{}", *PROCESS_ID, et_ptr as usize); let physical_type = ArrowDataType::FixedSizeBinary(t_size); - let extension_type = - ArrowDataType::Extension(EXTENSION_NAME.into(), physical_type.into(), Some(metadata)); + let extension_type = ArrowDataType::Extension( + PlSmallStr::from_static(EXTENSION_NAME), + physical_type.into(), + Some(metadata), + ); // first freeze, otherwise we compute null let validity = if null_count > 0 { Some(validity.into()) @@ -217,7 +221,7 @@ mod test { }; let values 
= &[Some(foo1), None, Some(foo2), None]; - let ca = ObjectChunked::new("", values); + let ca = ObjectChunked::new(PlSmallStr::EMPTY, values); let groups = GroupsProxy::Idx(vec![(0, unitvec![0, 1]), (2, unitvec![2]), (3, unitvec![3])].into()); @@ -241,7 +245,7 @@ mod test { }; let values = &[Some(foo1.clone()), None, Some(foo2.clone()), None]; - let ca = ObjectChunked::new("", values); + let ca = ObjectChunked::new(PlSmallStr::EMPTY, values); let groups = vec![(0, unitvec![0, 1]), (2, unitvec![2]), (3, unitvec![3])].into(); let out = unsafe { ca.agg_list(&GroupsProxy::Idx(groups)) }; diff --git a/crates/polars-core/src/chunked_array/object/extension/polars_extension.rs b/crates/polars-core/src/chunked_array/object/extension/polars_extension.rs index 6030f668dfe1..424c8aaccf6c 100644 --- a/crates/polars-core/src/chunked_array/object/extension/polars_extension.rs +++ b/crates/polars-core/src/chunked_array/object/extension/polars_extension.rs @@ -13,7 +13,11 @@ impl PolarsExtension { let arr = arr.slice_typed_unchecked(i, 1); let pe = Self::new(arr); let pe = ManuallyDrop::new(pe); - pe.get_series("").get(0).unwrap().into_static().unwrap() + pe.get_series(&PlSmallStr::EMPTY) + .get(0) + .unwrap() + .into_static() + .unwrap() } pub(crate) unsafe fn new(array: FixedSizeBinaryArray) -> Self { @@ -38,8 +42,7 @@ impl PolarsExtension { /// Load the sentinel from the heap. /// be very careful, this dereferences a raw pointer on the heap, unsafe fn get_sentinel(&self) -> Box { - if let ArrowDataType::Extension(_, _, Some(metadata)) = - self.array.as_ref().unwrap().data_type() + if let ArrowDataType::Extension(_, _, Some(metadata)) = self.array.as_ref().unwrap().dtype() { let mut iter = metadata.split(';'); @@ -57,7 +60,7 @@ impl PolarsExtension { /// Calls the heap allocated function in the `[ExtensionSentinel]` that knows /// how to convert the `[FixedSizeBinaryArray]` to a `Series` of type `[ObjectChunked]` - pub(crate) unsafe fn get_series(&self, name: &str) -> Series { + pub(crate) unsafe fn get_series(&self, name: &PlSmallStr) -> Series { self.with_sentinel(|sent| { (sent.to_series_fn.as_ref().unwrap())(self.array.as_ref().unwrap(), name) }) @@ -66,7 +69,7 @@ impl PolarsExtension { // heap allocates a function that converts the binary array to a Series of `[ObjectChunked]` // the `name` will be the `name` of the output `Series` when this function is called (later). 
pub(crate) unsafe fn set_to_series_fn(&mut self) { - let f = Box::new(move |arr: &FixedSizeBinaryArray, name: &str| { + let f = Box::new(move |arr: &FixedSizeBinaryArray, name: &PlSmallStr| { let iter = arr.iter().map(|opt| { opt.map(|bytes| { let t = std::ptr::read_unaligned(bytes.as_ptr() as *const T); @@ -77,7 +80,7 @@ impl PolarsExtension { }) }); - let ca = ObjectChunked::::from_iter_options(name, iter); + let ca = ObjectChunked::::from_iter_options(name.clone(), iter); ca.into_series() }); self.with_sentinel(move |sent| { diff --git a/crates/polars-core/src/chunked_array/object/mod.rs b/crates/polars-core/src/chunked_array/object/mod.rs index fffe547f6d0c..1b018800dd98 100644 --- a/crates/polars-core/src/chunked_array/object/mod.rs +++ b/crates/polars-core/src/chunked_array/object/mod.rs @@ -169,7 +169,7 @@ where self } - fn data_type(&self) -> &ArrowDataType { + fn dtype(&self) -> &ArrowDataType { &ArrowDataType::FixedSizeBinary(std::mem::size_of::()) } diff --git a/crates/polars-core/src/chunked_array/object/registry.rs b/crates/polars-core/src/chunked_array/object/registry.rs index 5ebcad2a022a..e84c7ab69ba5 100644 --- a/crates/polars-core/src/chunked_array/object/registry.rs +++ b/crates/polars-core/src/chunked_array/object/registry.rs @@ -1,4 +1,5 @@ //! This is a heap allocated utility that can be used to register an object type. +//! //! That object type will know its own generic type parameter `T` and callers can simply //! send `&Any` values and don't have to know the generic type themselves. use std::any::Any; @@ -8,6 +9,7 @@ use std::sync::{Arc, RwLock}; use arrow::datatypes::ArrowDataType; use once_cell::sync::Lazy; +use polars_utils::pl_str::PlSmallStr; use crate::chunked_array::object::builder::ObjectChunkedBuilder; use crate::datatypes::AnyValue; @@ -16,7 +18,7 @@ use crate::series::{IntoSeries, Series}; /// Takes a `name` and `capacity` and constructs a new builder. pub type BuilderConstructor = - Box Box + Send + Sync>; + Box Box + Send + Sync>; pub type ObjectConverter = Arc Box + Send + Sync>; pub struct ObjectRegistry { @@ -115,7 +117,7 @@ pub fn get_object_physical_type() -> ArrowDataType { reg.physical_dtype.clone() } -pub fn get_object_builder(name: &str, capacity: usize) -> Box { +pub fn get_object_builder(name: PlSmallStr, capacity: usize) -> Box { let reg = GLOBAL_OBJECT_REGISTRY.read().unwrap(); let reg = reg.as_ref().unwrap(); (reg.builder_constructor)(name, capacity) diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs index b6848d80b652..b94d724a5185 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs @@ -88,6 +88,10 @@ where ) } + fn _sum_as_f64(&self) -> f64 { + self.downcast_iter().map(float_sum::sum_arr_as_f64).sum() + } + fn min(&self) -> Option { if self.null_count() == self.len() { return None; @@ -216,13 +220,11 @@ where } fn mean(&self) -> Option { - if self.null_count() == self.len() { + let count = self.len() - self.null_count(); + if count == 0 { return None; } - - let len = (self.len() - self.null_count()) as f64; - let sum: f64 = self.downcast_iter().map(float_sum::sum_arr_as_f64).sum(); - Some(sum / len) + Some(self._sum_as_f64() / count as f64) } } @@ -636,9 +638,9 @@ mod test { // Validated with numpy. Note that numpy uses ddof as an argument which // influences results. The default ddof=0, we chose ddof=1, which is // standard in statistics. 
- let ca1 = Int32Chunked::new("", &[5, 8, 9, 5, 0]); + let ca1 = Int32Chunked::new(PlSmallStr::EMPTY, &[5, 8, 9, 5, 0]); let ca2 = Int32Chunked::new( - "", + PlSmallStr::EMPTY, &[ Some(5), None, @@ -660,11 +662,11 @@ mod test { #[test] fn test_agg_float() { - let ca1 = Float32Chunked::new("a", &[1.0, f32::NAN]); - let ca2 = Float32Chunked::new("b", &[f32::NAN, 1.0]); + let ca1 = Float32Chunked::new(PlSmallStr::from_static("a"), &[1.0, f32::NAN]); + let ca2 = Float32Chunked::new(PlSmallStr::from_static("b"), &[f32::NAN, 1.0]); assert_eq!(ca1.min(), ca2.min()); - let ca1 = Float64Chunked::new("a", &[1.0, f64::NAN]); - let ca2 = Float64Chunked::from_slice("b", &[f64::NAN, 1.0]); + let ca1 = Float64Chunked::new(PlSmallStr::from_static("a"), &[1.0, f64::NAN]); + let ca2 = Float64Chunked::from_slice(PlSmallStr::from_static("b"), &[f64::NAN, 1.0]); assert_eq!(ca1.min(), ca2.min()); println!("{:?}", (ca1.min(), ca2.min())) } @@ -672,12 +674,12 @@ mod test { #[test] fn test_median() { let ca = UInt32Chunked::new( - "a", + PlSmallStr::from_static("a"), &[Some(2), Some(1), None, Some(3), Some(5), None, Some(4)], ); assert_eq!(ca.median(), Some(3.0)); let ca = UInt32Chunked::new( - "a", + PlSmallStr::from_static("a"), &[ None, Some(7), @@ -694,7 +696,7 @@ mod test { assert_eq!(ca.median(), Some(4.0)); let ca = Float32Chunked::from_slice( - "", + PlSmallStr::EMPTY, &[ 0.166189, 0.166559, 0.168517, 0.169393, 0.175272, 0.233167, 0.238787, 0.266562, 0.26903, 0.285792, 0.292801, 0.293429, 0.301706, 0.308534, 0.331489, 0.346095, @@ -707,7 +709,7 @@ mod test { #[test] fn test_mean() { - let ca = Float32Chunked::new("", &[Some(1.0), Some(2.0), None]); + let ca = Float32Chunked::new(PlSmallStr::EMPTY, &[Some(1.0), Some(2.0), None]); assert_eq!(ca.mean().unwrap(), 1.5); assert_eq!( ca.into_series() @@ -718,7 +720,7 @@ mod test { 1.5 ); // all null values case - let ca = Float32Chunked::full_null("", 3); + let ca = Float32Chunked::full_null(PlSmallStr::EMPTY, 3); assert_eq!(ca.mean(), None); assert_eq!( ca.into_series().mean_reduce().value().extract::(), @@ -728,10 +730,10 @@ mod test { #[test] fn test_quantile_all_null() { - let test_f32 = Float32Chunked::from_slice_options("", &[None, None, None]); - let test_i32 = Int32Chunked::from_slice_options("", &[None, None, None]); - let test_f64 = Float64Chunked::from_slice_options("", &[None, None, None]); - let test_i64 = Int64Chunked::from_slice_options("", &[None, None, None]); + let test_f32 = Float32Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]); + let test_i32 = Int32Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]); + let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]); + let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, &[None, None, None]); let interpol_options = vec![ QuantileInterpolOptions::Nearest, @@ -751,10 +753,10 @@ mod test { #[test] fn test_quantile_single_value() { - let test_f32 = Float32Chunked::from_slice_options("", &[Some(1.0)]); - let test_i32 = Int32Chunked::from_slice_options("", &[Some(1)]); - let test_f64 = Float64Chunked::from_slice_options("", &[Some(1.0)]); - let test_i64 = Int64Chunked::from_slice_options("", &[Some(1)]); + let test_f32 = Float32Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]); + let test_i32 = Int32Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1)]); + let test_f64 = Float64Chunked::from_slice_options(PlSmallStr::EMPTY, &[Some(1.0)]); + let test_i64 = Int64Chunked::from_slice_options(PlSmallStr::EMPTY, 
&[Some(1)]); let interpol_options = vec![ QuantileInterpolOptions::Nearest, @@ -774,14 +776,22 @@ mod test { #[test] fn test_quantile_min_max() { - let test_f32 = - Float32Chunked::from_slice_options("", &[None, Some(1f32), Some(5f32), Some(1f32)]); - let test_i32 = - Int32Chunked::from_slice_options("", &[None, Some(1i32), Some(5i32), Some(1i32)]); - let test_f64 = - Float64Chunked::from_slice_options("", &[None, Some(1f64), Some(5f64), Some(1f64)]); - let test_i64 = - Int64Chunked::from_slice_options("", &[None, Some(1i64), Some(5i64), Some(1i64)]); + let test_f32 = Float32Chunked::from_slice_options( + PlSmallStr::EMPTY, + &[None, Some(1f32), Some(5f32), Some(1f32)], + ); + let test_i32 = Int32Chunked::from_slice_options( + PlSmallStr::EMPTY, + &[None, Some(1i32), Some(5i32), Some(1i32)], + ); + let test_f64 = Float64Chunked::from_slice_options( + PlSmallStr::EMPTY, + &[None, Some(1f64), Some(5f64), Some(1f64)], + ); + let test_i64 = Int64Chunked::from_slice_options( + PlSmallStr::EMPTY, + &[None, Some(1i64), Some(5i64), Some(1i64)], + ); let interpol_options = vec![ QuantileInterpolOptions::Nearest, @@ -822,7 +832,7 @@ mod test { #[test] fn test_quantile() { let ca = UInt32Chunked::new( - "a", + PlSmallStr::from_static("a"), &[Some(2), Some(1), None, Some(3), Some(5), None, Some(4)], ); @@ -896,7 +906,7 @@ mod test { ); let ca = UInt32Chunked::new( - "a", + PlSmallStr::from_static("a"), &[ None, Some(7), diff --git a/crates/polars-core/src/chunked_array/ops/any_value.rs b/crates/polars-core/src/chunked_array/ops/any_value.rs index f9a959e582a2..2a50b24d9bbf 100644 --- a/crates/polars-core/src/chunked_array/ops/any_value.rs +++ b/crates/polars-core/src/chunked_array/ops/any_value.rs @@ -49,12 +49,16 @@ pub(crate) unsafe fn arr_to_any_value<'a>( DataType::List(dt) => { let v: ArrayRef = downcast!(LargeListArray); if dt.is_primitive() { - let s = Series::from_chunks_and_dtype_unchecked("", vec![v], dt); + let s = Series::from_chunks_and_dtype_unchecked(PlSmallStr::EMPTY, vec![v], dt); AnyValue::List(s) } else { - let s = Series::from_chunks_and_dtype_unchecked("", vec![v], &dt.to_physical()) - .cast_unchecked(dt) - .unwrap(); + let s = Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![v], + &dt.to_physical(), + ) + .cast_unchecked(dt) + .unwrap(); AnyValue::List(s) } }, @@ -62,12 +66,16 @@ pub(crate) unsafe fn arr_to_any_value<'a>( DataType::Array(dt, width) => { let v: ArrayRef = downcast!(FixedSizeListArray); if dt.is_primitive() { - let s = Series::from_chunks_and_dtype_unchecked("", vec![v], dt); + let s = Series::from_chunks_and_dtype_unchecked(PlSmallStr::EMPTY, vec![v], dt); AnyValue::Array(s, *width) } else { - let s = Series::from_chunks_and_dtype_unchecked("", vec![v], &dt.to_physical()) - .cast_unchecked(dt) - .unwrap(); + let s = Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![v], + &dt.to_physical(), + ) + .cast_unchecked(dt) + .unwrap(); AnyValue::Array(s, *width) } }, @@ -153,7 +161,7 @@ impl<'a> AnyValue<'a> { if arr.is_valid_unchecked(idx) { let v = arr.value_unchecked(idx); - match fld.data_type() { + match fld.dtype() { DataType::Categorical(Some(rev_map), _) => { AnyValue::Categorical( v, @@ -170,13 +178,13 @@ impl<'a> AnyValue<'a> { AnyValue::Null } } else { - arr_to_any_value(&**arr, idx, fld.data_type()) + arr_to_any_value(&**arr, idx, fld.dtype()) } } #[cfg(not(feature = "dtype-categorical"))] { - arr_to_any_value(&**arr, idx, fld.data_type()) + arr_to_any_value(&**arr, idx, fld.dtype()) } }) } diff --git 
a/crates/polars-core/src/chunked_array/ops/append.rs b/crates/polars-core/src/chunked_array/ops/append.rs index 3c8c0fa2518c..383c76d63600 100644 --- a/crates/polars-core/src/chunked_array/ops/append.rs +++ b/crates/polars-core/src/chunked_array/ops/append.rs @@ -1,3 +1,5 @@ +use polars_error::constants::LENGTH_LIMIT_MSG; + use crate::prelude::*; use crate::series::IsSorted; @@ -136,12 +138,16 @@ where /// Append in place. This is done by adding the chunks of `other` to this [`ChunkedArray`]. /// /// See also [`extend`](Self::extend) for appends to the underlying memory - pub fn append(&mut self, other: &Self) { + pub fn append(&mut self, other: &Self) -> PolarsResult<()> { update_sorted_flag_before_append::(self, other); let len = self.len(); - self.length += other.length; + self.length = self + .length + .checked_add(other.length) + .ok_or_else(|| polars_err!(ComputeError: LENGTH_LIMIT_MSG))?; self.null_count += other.null_count; new_chunks(&mut self.chunks, &other.chunks, len); + Ok(()) } } @@ -149,10 +155,13 @@ where impl ListChunked { pub fn append(&mut self, other: &Self) -> PolarsResult<()> { let dtype = merge_dtypes(self.dtype(), other.dtype())?; - self.field = Arc::new(Field::new(self.name(), dtype)); + self.field = Arc::new(Field::new(self.name().clone(), dtype)); let len = self.len(); - self.length += other.length; + self.length = self + .length + .checked_add(other.length) + .ok_or_else(|| polars_err!(ComputeError: LENGTH_LIMIT_MSG))?; self.null_count += other.null_count; new_chunks(&mut self.chunks, &other.chunks, len); self.set_sorted_flag(IsSorted::Not); @@ -168,11 +177,14 @@ impl ListChunked { impl ArrayChunked { pub fn append(&mut self, other: &Self) -> PolarsResult<()> { let dtype = merge_dtypes(self.dtype(), other.dtype())?; - self.field = Arc::new(Field::new(self.name(), dtype)); + self.field = Arc::new(Field::new(self.name().clone(), dtype)); let len = self.len(); - self.length += other.length; + self.length = self + .length + .checked_add(other.length) + .ok_or_else(|| polars_err!(ComputeError: LENGTH_LIMIT_MSG))?; self.null_count += other.null_count; new_chunks(&mut self.chunks, &other.chunks, len); @@ -186,11 +198,14 @@ impl ArrayChunked { impl StructChunked { pub fn append(&mut self, other: &Self) -> PolarsResult<()> { let dtype = merge_dtypes(self.dtype(), other.dtype())?; - self.field = Arc::new(Field::new(self.name(), dtype)); + self.field = Arc::new(Field::new(self.name().clone(), dtype)); let len = self.len(); - self.length += other.length; + self.length = self + .length + .checked_add(other.length) + .ok_or_else(|| polars_err!(ComputeError: LENGTH_LIMIT_MSG))?; self.null_count += other.null_count; new_chunks(&mut self.chunks, &other.chunks, len); @@ -202,11 +217,15 @@ impl StructChunked { #[cfg(feature = "object")] #[doc(hidden)] impl ObjectChunked { - pub fn append(&mut self, other: &Self) { + pub fn append(&mut self, other: &Self) -> PolarsResult<()> { let len = self.len(); - self.length += other.length; + self.length = self + .length + .checked_add(other.length) + .ok_or_else(|| polars_err!(ComputeError: LENGTH_LIMIT_MSG))?; self.null_count += other.null_count; self.set_sorted_flag(IsSorted::Not); new_chunks(&mut self.chunks, &other.chunks, len); + Ok(()) } } diff --git a/crates/polars-core/src/chunked_array/ops/apply.rs b/crates/polars-core/src/chunked_array/ops/apply.rs index de62f1eddc7f..da63d7323aa6 100644 --- a/crates/polars-core/src/chunked_array/ops/apply.rs +++ b/crates/polars-core/src/chunked_array/ops/apply.rs @@ -37,7 +37,7 @@ where } }); - 
ChunkedArray::from_chunk_iter(self.name(), iter) + ChunkedArray::from_chunk_iter(self.name().clone(), iter) } /// Applies a function only to the non-null elements, propagating nulls. @@ -64,7 +64,7 @@ where Ok(arr) }); - ChunkedArray::try_from_chunk_iter(self.name(), iter) + ChunkedArray::try_from_chunk_iter(self.name().clone(), iter) } pub fn apply_into_string_amortized<'a, F>(&'a self, mut f: F) -> StringChunked @@ -87,7 +87,7 @@ where mutarr.freeze() }) .collect::>(); - ChunkedArray::from_chunk_iter(self.name(), chunks) + ChunkedArray::from_chunk_iter(self.name().clone(), chunks) } pub fn try_apply_into_string_amortized<'a, F, E>(&'a self, mut f: F) -> Result @@ -112,11 +112,11 @@ where Ok(mutarr.freeze()) }) .collect::>(); - ChunkedArray::try_from_chunk_iter(self.name(), chunks) + ChunkedArray::try_from_chunk_iter(self.name().clone(), chunks) } } -fn apply_in_place_impl(name: &str, chunks: Vec, f: F) -> ChunkedArray +fn apply_in_place_impl(name: PlSmallStr, chunks: Vec, f: F) -> ChunkedArray where F: Fn(S::Native) -> S::Native + Copy, S: PolarsNumericType, @@ -170,7 +170,7 @@ impl ChunkedArray { .unwrap(); s.chunks().clone() }; - apply_in_place_impl(self.name(), chunks, f) + apply_in_place_impl(self.name().clone(), chunks, f) } /// Cast a numeric array to another numeric data type and apply a function in place. @@ -180,7 +180,7 @@ impl ChunkedArray { F: Fn(T::Native) -> T::Native + Copy, { let chunks = std::mem::take(&mut self.chunks); - apply_in_place_impl(self.name(), chunks, f) + apply_in_place_impl(self.name().clone(), chunks, f) } } @@ -217,7 +217,7 @@ where let arr: T::Array = slice.iter().copied().map(f).collect_arr(); arr.with_validity(validity.cloned()) }); - ChunkedArray::from_chunk_iter(self.name(), chunks) + ChunkedArray::from_chunk_iter(self.name().clone(), chunks) } fn apply(&'a self, f: F) -> Self @@ -228,7 +228,7 @@ where let iter = arr.into_iter().map(|opt_v| f(opt_v.copied())); PrimitiveArray::::from_trusted_len_iter(iter) }); - Self::from_chunk_iter(self.name(), chunks) + Self::from_chunk_iter(self.name().clone(), chunks) } fn apply_to_slice(&'a self, f: F, slice: &mut [V]) @@ -312,7 +312,7 @@ impl StringChunked { let new = Utf8ViewArray::arr_from_iter(iter); new.with_validity(arr.validity().cloned()) }); - StringChunked::from_chunk_iter(self.name(), chunks) + StringChunked::from_chunk_iter(self.name().clone(), chunks) } } @@ -326,7 +326,7 @@ impl BinaryChunked { let new = BinaryViewArray::arr_from_iter(iter); new.with_validity(arr.validity().cloned()) }); - BinaryChunked::from_chunk_iter(self.name(), chunks) + BinaryChunked::from_chunk_iter(self.name().clone(), chunks) } } @@ -405,7 +405,7 @@ impl<'a> ChunkApply<'a, &'a [u8]> for BinaryChunked { impl ChunkApplyKernel for BooleanChunked { fn apply_kernel(&self, f: &dyn Fn(&BooleanArray) -> ArrayRef) -> Self { let chunks = self.downcast_iter().map(f).collect(); - unsafe { Self::from_chunks(self.name(), chunks) } + unsafe { Self::from_chunks(self.name().clone(), chunks) } } fn apply_kernel_cast(&self, f: &dyn Fn(&BooleanArray) -> ArrayRef) -> ChunkedArray @@ -413,7 +413,7 @@ impl ChunkApplyKernel for BooleanChunked { S: PolarsDataType, { let chunks = self.downcast_iter().map(f).collect(); - unsafe { ChunkedArray::::from_chunks(self.name(), chunks) } + unsafe { ChunkedArray::::from_chunks(self.name().clone(), chunks) } } } @@ -432,7 +432,7 @@ where S: PolarsDataType, { let chunks = self.downcast_iter().map(f).collect(); - unsafe { ChunkedArray::from_chunks(self.name(), chunks) } + unsafe { 
ChunkedArray::from_chunks(self.name().clone(), chunks) } } } @@ -446,7 +446,7 @@ impl ChunkApplyKernel for StringChunked { S: PolarsDataType, { let chunks = self.downcast_iter().map(f).collect(); - unsafe { ChunkedArray::from_chunks(self.name(), chunks) } + unsafe { ChunkedArray::from_chunks(self.name().clone(), chunks) } } } @@ -460,7 +460,7 @@ impl ChunkApplyKernel for BinaryChunked { S: PolarsDataType, { let chunks = self.downcast_iter().map(f).collect(); - unsafe { ChunkedArray::from_chunks(self.name(), chunks) } + unsafe { ChunkedArray::from_chunks(self.name().clone(), chunks) } } } @@ -519,7 +519,8 @@ impl<'a> ChunkApply<'a, Series> for ListChunked { let mut idx = 0; self.downcast_iter().for_each(|arr| { arr.iter().for_each(|opt_val| { - let opt_val = opt_val.map(|arrayref| Series::try_from(("", arrayref)).unwrap()); + let opt_val = opt_val + .map(|arrayref| Series::try_from((PlSmallStr::EMPTY, arrayref)).unwrap()); // SAFETY: // length asserted above @@ -543,7 +544,7 @@ where F: Fn(&'a T) -> T + Copy, { let mut ca: ObjectChunked = self.into_iter().map(|opt_v| opt_v.map(f)).collect(); - ca.rename(self.name()); + ca.rename(self.name().clone()); ca } @@ -552,7 +553,7 @@ where F: Fn(Option<&'a T>) -> Option + Copy, { let mut ca: ObjectChunked = self.into_iter().map(f).collect(); - ca.rename(self.name()); + ca.rename(self.name().clone()); ca } diff --git a/crates/polars-core/src/chunked_array/ops/arity.rs b/crates/polars-core/src/chunked_array/ops/arity.rs index c69ae14a5866..774b6fba6755 100644 --- a/crates/polars-core/src/chunked_array/ops/arity.rs +++ b/crates/polars-core/src/chunked_array/ops/arity.rs @@ -3,6 +3,7 @@ use std::error::Error; use arrow::array::{Array, MutablePlString, StaticArray}; use arrow::compute::utils::combine_validities_and; use polars_error::PolarsResult; +use polars_utils::pl_str::PlSmallStr; use crate::chunked_array::metadata::MetadataProperties; use crate::datatypes::{ArrayCollectIterExt, ArrayFromIter}; @@ -49,7 +50,7 @@ where F: FnMut(&T::Array) -> Arr, { let iter = ca.downcast_iter().map(op); - ChunkedArray::from_chunk_iter(ca.name(), iter) + ChunkedArray::from_chunk_iter(ca.name().clone(), iter) } /// Applies a kernel that produces `Array` types. 
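Related to the ops/append.rs hunks above: `append` on ChunkedArray (and the object/list/array/struct variants) now returns `PolarsResult<()>` because the combined length is checked against the length limit. A small sketch of the new call-site shape (not part of the patch), mirroring the updated `slice` test:

    use polars_core::prelude::*;

    fn concat_parts() -> PolarsResult<UInt32Chunked> {
        let mut first = UInt32Chunked::new(PlSmallStr::from_static("first"), &[0, 1, 2]);
        let second = UInt32Chunked::new(PlSmallStr::from_static("second"), &[3, 4, 5]);
        // `append` is now fallible; propagate (or unwrap) the length-limit error.
        first.append(&second)?;
        Ok(first)
    }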
@@ -61,9 +62,9 @@ where Arr: Array, F: FnMut(T::Array) -> Arr, { - let name = ca.name().to_owned(); + let name = ca.name().clone(); let iter = ca.downcast_into_iter().map(op); - ChunkedArray::from_chunk_iter(&name, iter) + ChunkedArray::from_chunk_iter(name, iter) } #[inline] @@ -78,12 +79,12 @@ where let iter = ca .downcast_iter() .map(|arr| arr.iter().map(&mut op).collect_arr()); - ChunkedArray::from_chunk_iter(ca.name(), iter) + ChunkedArray::from_chunk_iter(ca.name().clone(), iter) } else { let iter = ca .downcast_iter() .map(|arr| arr.values_iter().map(|x| op(Some(x))).collect_arr()); - ChunkedArray::from_chunk_iter(ca.name(), iter) + ChunkedArray::from_chunk_iter(ca.name().clone(), iter) } } @@ -101,7 +102,7 @@ where let iter = ca .downcast_iter() .map(|arr| arr.iter().map(&mut op).try_collect_arr()); - ChunkedArray::try_from_chunk_iter(ca.name(), iter) + ChunkedArray::try_from_chunk_iter(ca.name().clone(), iter) } #[inline] @@ -114,7 +115,7 @@ where { if ca.null_count() == ca.len() { let arr = V::Array::full_null(ca.len(), V::get_dtype().to_arrow(CompatLevel::newest())); - return ChunkedArray::with_chunk(ca.name(), arr); + return ChunkedArray::with_chunk(ca.name().clone(), arr); } let iter = ca.downcast_iter().map(|arr| { @@ -122,7 +123,7 @@ where let arr: V::Array = arr.values_iter().map(&mut op).collect_arr(); arr.with_validity_typed(validity) }); - ChunkedArray::from_chunk_iter(ca.name(), iter) + ChunkedArray::from_chunk_iter(ca.name().clone(), iter) } #[inline] @@ -138,7 +139,7 @@ where { if ca.null_count() == ca.len() { let arr = V::Array::full_null(ca.len(), V::get_dtype().to_arrow(CompatLevel::newest())); - return Ok(ChunkedArray::with_chunk(ca.name(), arr)); + return Ok(ChunkedArray::with_chunk(ca.name().clone(), arr)); } let iter = ca.downcast_iter().map(|arr| { @@ -146,7 +147,7 @@ where let arr: V::Array = arr.values_iter().map(&mut op).try_collect_arr()?; Ok(arr.with_validity_typed(validity)) }); - ChunkedArray::try_from_chunk_iter(ca.name(), iter) + ChunkedArray::try_from_chunk_iter(ca.name().clone(), iter) } /// Applies a kernel that produces `Array` types. @@ -164,7 +165,7 @@ where let iter = ca .downcast_iter() .map(|arr| op(arr).with_validity_typed(arr.validity().cloned())); - ChunkedArray::from_chunk_iter(ca.name(), iter) + ChunkedArray::from_chunk_iter(ca.name().clone(), iter) } /// Applies a kernel that produces `Array` types. 
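These hunks also replace the old anonymous name `""` with the shared `PlSmallStr::EMPTY` constant; a trivial sketch (not part of the patch):

    use polars_core::prelude::*;

    // An unnamed Series/ChunkedArray now carries the empty-name constant.
    let s = Series::new(PlSmallStr::EMPTY, &[1i32, 2, 3]);
    assert_eq!(s.name(), &PlSmallStr::EMPTY);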
@@ -176,7 +177,7 @@ where Arr: Array + StaticArray, F: FnMut(&T::Array) -> Arr, { - ChunkedArray::from_chunk_iter(ca.name(), ca.downcast_iter().map(op)) + ChunkedArray::from_chunk_iter(ca.name().clone(), ca.downcast_iter().map(op)) } #[inline] @@ -191,7 +192,7 @@ where F: FnMut(&T::Array) -> Result, E: Error, { - ChunkedArray::try_from_chunk_iter(ca.name(), ca.downcast_iter().map(op)) + ChunkedArray::try_from_chunk_iter(ca.name().clone(), ca.downcast_iter().map(op)) } #[inline] @@ -220,7 +221,7 @@ where .map(|(lhs_opt_val, rhs_opt_val)| op(lhs_opt_val, rhs_opt_val)); element_iter.collect_arr() }); - ChunkedArray::from_chunk_iter(lhs.name(), iter) + ChunkedArray::from_chunk_iter(lhs.name().clone(), iter) } #[inline] @@ -297,7 +298,7 @@ where .map(|(lhs_opt_val, rhs_opt_val)| op(lhs_opt_val, rhs_opt_val)); element_iter.try_collect_arr() }); - ChunkedArray::try_from_chunk_iter(lhs.name(), iter) + ChunkedArray::try_from_chunk_iter(lhs.name().clone(), iter) } #[inline] @@ -317,7 +318,7 @@ where let len = lhs.len().min(rhs.len()); let arr = V::Array::full_null(len, V::get_dtype().to_arrow(CompatLevel::newest())); - return ChunkedArray::with_chunk(lhs.name(), arr); + return ChunkedArray::with_chunk(lhs.name().clone(), arr); } let (lhs, rhs) = align_chunks_binary(lhs, rhs); @@ -336,7 +337,7 @@ where let array: V::Array = element_iter.collect_arr(); array.with_validity_typed(validity) }); - ChunkedArray::from_chunk_iter(lhs.name(), iter) + ChunkedArray::from_chunk_iter(lhs.name().clone(), iter) } /// Apply elementwise binary function which produces string, amortising allocations. @@ -373,7 +374,7 @@ where }); mutarr.freeze() }); - ChunkedArray::from_chunk_iter(lhs.name(), iter) + ChunkedArray::from_chunk_iter(lhs.name().clone(), iter) } /// Applies a kernel that produces `Array` types. @@ -385,7 +386,7 @@ pub fn binary_mut_values( lhs: &ChunkedArray, rhs: &ChunkedArray, mut op: F, - name: &str, + name: PlSmallStr, ) -> ChunkedArray where T: PolarsDataType, @@ -413,7 +414,7 @@ pub fn binary_mut_with_options( lhs: &ChunkedArray, rhs: &ChunkedArray, mut op: F, - name: &str, + name: PlSmallStr, ) -> ChunkedArray where T: PolarsDataType, @@ -435,7 +436,7 @@ pub fn try_binary_mut_with_options( lhs: &ChunkedArray, rhs: &ChunkedArray, mut op: F, - name: &str, + name: PlSmallStr, ) -> Result, E> where T: PolarsDataType, @@ -466,7 +467,7 @@ where Arr: Array, F: FnMut(&T::Array, &U::Array) -> Arr, { - binary_mut_with_options(lhs, rhs, op, lhs.name()) + binary_mut_with_options(lhs, rhs, op, lhs.name().clone()) } /// Applies a kernel that produces `Array` types. @@ -482,13 +483,13 @@ where Arr: Array, F: FnMut(L::Array, R::Array) -> Arr, { - let name = lhs.name().to_owned(); + let name = lhs.name().clone(); let (lhs, rhs) = align_chunks_binary_owned(lhs, rhs); let iter = lhs .downcast_into_iter() .zip(rhs.downcast_into_iter()) .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr)); - ChunkedArray::from_chunk_iter(&name, iter) + ChunkedArray::from_chunk_iter(name, iter) } /// Applies a kernel that produces `Array` types. @@ -510,7 +511,7 @@ where .downcast_iter() .zip(rhs.downcast_iter()) .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr)); - ChunkedArray::try_from_chunk_iter(lhs.name(), iter) + ChunkedArray::try_from_chunk_iter(lhs.name().clone(), iter) } /// Applies a kernel that produces `ArrayRef` of the same type. 
@@ -566,7 +567,7 @@ where .zip(rhs.downcast_iter()) .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr)) .collect::>(); - Series::try_from((lhs.name(), chunks)) + Series::try_from((lhs.name().clone(), chunks)) } /// Applies a kernel that produces `ArrayRef` of the same type. @@ -636,7 +637,7 @@ where ); element_iter.try_collect_arr() }); - ChunkedArray::try_from_chunk_iter(ca1.name(), iter) + ChunkedArray::try_from_chunk_iter(ca1.name().clone(), iter) } #[inline] @@ -677,7 +678,7 @@ where ); element_iter.collect_arr() }); - ChunkedArray::from_chunk_iter(ca1.name(), iter) + ChunkedArray::from_chunk_iter(ca1.name().clone(), iter) } pub fn broadcast_binary_elementwise( @@ -697,7 +698,7 @@ where match (lhs.len(), rhs.len()) { (1, _) => { let a = unsafe { lhs.get_unchecked(0) }; - unary_elementwise(rhs, |b| op(a.clone(), b)).with_name(lhs.name()) + unary_elementwise(rhs, |b| op(a.clone(), b)).with_name(lhs.name().clone()) }, (_, 1) => { let b = unsafe { rhs.get_unchecked(0) }; @@ -722,7 +723,7 @@ where match (lhs.len(), rhs.len()) { (1, _) => { let a = unsafe { lhs.get_unchecked(0) }; - Ok(try_unary_elementwise(rhs, |b| op(a.clone(), b))?.with_name(lhs.name())) + Ok(try_unary_elementwise(rhs, |b| op(a.clone(), b))?.with_name(lhs.name().clone())) }, (_, 1) => { let b = unsafe { rhs.get_unchecked(0) }; @@ -750,13 +751,13 @@ where let len = if min == 1 { max } else { min }; let arr = V::Array::full_null(len, V::get_dtype().to_arrow(CompatLevel::newest())); - return ChunkedArray::with_chunk(lhs.name(), arr); + return ChunkedArray::with_chunk(lhs.name().clone(), arr); } match (lhs.len(), rhs.len()) { (1, _) => { let a = unsafe { lhs.value_unchecked(0) }; - unary_elementwise_values(rhs, |b| op(a.clone(), b)).with_name(lhs.name()) + unary_elementwise_values(rhs, |b| op(a.clone(), b)).with_name(lhs.name().clone()) }, (_, 1) => { let b = unsafe { rhs.value_unchecked(0) }; @@ -793,7 +794,7 @@ where lhs.len(), O::get_dtype().to_arrow(CompatLevel::newest()), ); - ChunkedArray::::with_chunk(lhs.name(), arr) + ChunkedArray::::with_chunk(lhs.name().clone(), arr) }, Some(rhs) => unary_kernel(lhs, |arr| rhs_broadcast_kernel(arr, rhs.clone())), } @@ -806,14 +807,14 @@ where rhs.len(), O::get_dtype().to_arrow(CompatLevel::newest()), ); - ChunkedArray::::with_chunk(lhs.name(), arr) + ChunkedArray::::with_chunk(lhs.name().clone(), arr) }, Some(lhs) => unary_kernel(rhs, |arr| lhs_broadcast_kernel(lhs.clone(), arr)), } }, _ => panic!("Cannot apply operation on arrays of different lengths"), }; - out.with_name(name) + out.with_name(name.clone()) } pub fn apply_binary_kernel_broadcast_owned( @@ -843,7 +844,7 @@ where lhs.len(), O::get_dtype().to_arrow(CompatLevel::newest()), ); - ChunkedArray::::with_chunk(lhs.name(), arr) + ChunkedArray::::with_chunk(lhs.name().clone(), arr) }, Some(rhs) => unary_kernel_owned(lhs, |arr| rhs_broadcast_kernel(arr, rhs.clone())), } @@ -856,12 +857,12 @@ where rhs.len(), O::get_dtype().to_arrow(CompatLevel::newest()), ); - ChunkedArray::::with_chunk(lhs.name(), arr) + ChunkedArray::::with_chunk(lhs.name().clone(), arr) }, Some(lhs) => unary_kernel_owned(rhs, |arr| lhs_broadcast_kernel(lhs.clone(), arr)), } }, _ => panic!("Cannot apply operation on arrays of different lengths"), }; - out.with_name(&name) + out.with_name(name) } diff --git a/crates/polars-core/src/chunked_array/ops/bit_repr.rs b/crates/polars-core/src/chunked_array/ops/bit_repr.rs index 9a2f1c33594a..7b20d77e2444 100644 --- a/crates/polars-core/src/chunked_array/ops/bit_repr.rs +++ 
b/crates/polars-core/src/chunked_array/ops/bit_repr.rs @@ -20,7 +20,7 @@ fn reinterpret_chunked_array( PrimitiveArray::from_data_default(reinterpreted_buf, array.validity().cloned()) }); - ChunkedArray::from_chunk_iter(ca.name(), chunks) + ChunkedArray::from_chunk_iter(ca.name().clone(), chunks) } /// Reinterprets the type of a [`ListChunked`]. T and U must have the same size @@ -53,7 +53,7 @@ fn reinterpret_list_chunked( ) }); - ListChunked::from_chunk_iter(ca.name(), chunks) + ListChunked::from_chunk_iter(ca.name().clone(), chunks) } #[cfg(all(feature = "reinterpret", feature = "dtype-i16", feature = "dtype-u16"))] diff --git a/crates/polars-core/src/chunked_array/ops/chunkops.rs b/crates/polars-core/src/chunked_array/ops/chunkops.rs index d97af95367e6..c221505ef936 100644 --- a/crates/polars-core/src/chunked_array/ops/chunkops.rs +++ b/crates/polars-core/src/chunked_array/ops/chunkops.rs @@ -363,7 +363,7 @@ impl ObjectChunked { if self.chunks.len() == 1 { self.clone() } else { - let mut builder = ObjectChunkedBuilder::new(self.name(), self.len()); + let mut builder = ObjectChunkedBuilder::new(self.name().clone(), self.len()); let chunks = self.downcast_iter(); // todo! use iterators once implemented @@ -398,7 +398,7 @@ mod test { #[test] #[cfg(feature = "dtype-categorical")] fn test_categorical_map_after_rechunk() { - let s = Series::new("", &["foo", "bar", "spam"]); + let s = Series::new(PlSmallStr::EMPTY, &["foo", "bar", "spam"]); let mut a = s .cast(&DataType::Categorical(None, Default::default())) .unwrap(); diff --git a/crates/polars-core/src/chunked_array/ops/decimal.rs b/crates/polars-core/src/chunked_array/ops/decimal.rs index e2f9c5845429..5f242ee37caa 100644 --- a/crates/polars-core/src/chunked_array/ops/decimal.rs +++ b/crates/polars-core/src/chunked_array/ops/decimal.rs @@ -43,7 +43,7 @@ mod test { "5.104", "5.25251525353", ]; - let s = StringChunked::from_slice("test", &vals); + let s = StringChunked::from_slice(PlSmallStr::from_str("test"), &vals); let s = s.to_decimal(6).unwrap(); assert_eq!(s.dtype(), &DataType::Decimal(None, Some(5))); assert_eq!(s.len(), 7); diff --git a/crates/polars-core/src/chunked_array/ops/explode.rs b/crates/polars-core/src/chunked_array/ops/explode.rs index b44ee0863a98..2d1bddb2f4e3 100644 --- a/crates/polars-core/src/chunked_array/ops/explode.rs +++ b/crates/polars-core/src/chunked_array/ops/explode.rs @@ -1,16 +1,9 @@ use arrow::array::*; use arrow::bitmap::utils::set_bit_unchecked; use arrow::bitmap::{Bitmap, MutableBitmap}; -use arrow::legacy::array::list::AnonymousBuilder; -#[cfg(feature = "dtype-array")] -use arrow::legacy::is_valid::IsValid; use arrow::legacy::prelude::*; -use arrow::legacy::trusted_len::TrustedLenPush; use polars_utils::slice::GetSaferUnchecked; -#[cfg(feature = "dtype-array")] -use crate::chunked_array::builder::get_fixed_size_list_builder; -use crate::chunked_array::metadata::MetadataProperties; use crate::prelude::*; use crate::series::implementations::null::NullChunked; @@ -154,18 +147,24 @@ where new_values.into(), Some(validity.into()), ); - Series::try_from((self.name(), Box::new(arr) as ArrayRef)).unwrap() + Series::try_from((self.name().clone(), Box::new(arr) as ArrayRef)).unwrap() } } impl ExplodeByOffsets for Float32Chunked { fn explode_by_offsets(&self, offsets: &[i64]) -> Series { - self.apply_as_ints(|s| s.explode_by_offsets(offsets)) + self.apply_as_ints(|s| { + let ca = s.u32().unwrap(); + ca.explode_by_offsets(offsets) + }) } } impl ExplodeByOffsets for Float64Chunked { fn explode_by_offsets(&self, 
offsets: &[i64]) -> Series { - self.apply_as_ints(|s| s.explode_by_offsets(offsets)) + self.apply_as_ints(|s| { + let ca = s.u64().unwrap(); + ca.explode_by_offsets(offsets) + }) } } @@ -190,7 +189,7 @@ impl ExplodeByOffsets for BooleanChunked { let arr = self.downcast_iter().next().unwrap(); let cap = get_capacity(offsets); - let mut builder = BooleanChunkedBuilder::new(self.name(), cap); + let mut builder = BooleanChunkedBuilder::new(self.name().clone(), cap); let mut start = offsets[0] as usize; let mut last = start; @@ -225,166 +224,6 @@ impl ExplodeByOffsets for BooleanChunked { } } -impl ExplodeByOffsets for ListChunked { - fn explode_by_offsets(&self, offsets: &[i64]) -> Series { - debug_assert_eq!(self.chunks.len(), 1); - let arr = self.downcast_iter().next().unwrap(); - - let cap = get_capacity(offsets); - let inner_type = self.inner_dtype(); - - let mut builder = arrow::legacy::array::list::AnonymousBuilder::new(cap); - let mut owned = Vec::with_capacity(cap); - let mut start = offsets[0] as usize; - let mut last = start; - - let mut process_range = |start: usize, last: usize, builder: &mut AnonymousBuilder<'_>| { - let vals = arr.slice_typed(start, last - start); - for opt_arr in vals.into_iter() { - match opt_arr { - None => builder.push_null(), - Some(arr) => { - unsafe { - // we create a pointer to evade the bck - let ptr = arr.as_ref() as *const dyn Array; - // SAFETY: we preallocated - owned.push_unchecked(arr); - // SAFETY: the pointer is still valid as `owned` will not reallocate - builder.push(&*ptr as &dyn Array); - } - }, - } - } - }; - - for &o in &offsets[1..] { - let o = o as usize; - if o == last { - if start != last { - process_range(start, last, &mut builder); - } - builder.push_null(); - start = o; - } - last = o; - } - process_range(start, last, &mut builder); - let arr = builder - .finish(Some(&inner_type.to_arrow(CompatLevel::newest()))) - .unwrap(); - let mut ca = unsafe { self.copy_with_chunks(vec![Box::new(arr)]) }; - - use MetadataProperties as P; - ca.copy_metadata(self, P::SORTED | P::FAST_EXPLODE_LIST); - - ca.into_series() - } -} - -#[cfg(feature = "dtype-array")] -impl ExplodeByOffsets for ArrayChunked { - fn explode_by_offsets(&self, offsets: &[i64]) -> Series { - debug_assert_eq!(self.chunks.len(), 1); - let arr = self.downcast_iter().next().unwrap(); - - let cap = get_capacity(offsets); - let inner_type = self.inner_dtype(); - let mut builder = - get_fixed_size_list_builder(inner_type, cap, self.width(), self.name()).unwrap(); - - let mut start = offsets[0] as usize; - let mut last = start; - for &o in &offsets[1..] 
{ - let o = o as usize; - if o == last { - if start != last { - let array = arr.slice_typed(start, last - start); - let values = array.values().as_ref(); - - for i in 0..array.len() { - unsafe { - if array.is_valid_unchecked(i) { - builder.push_unchecked(values, i) - } else { - builder.push_null() - } - } - } - } - unsafe { - builder.push_null(); - } - start = o; - } - last = o; - } - let array = arr.slice_typed(start, last - start); - let values = array.values().as_ref(); - for i in 0..array.len() { - unsafe { - if array.is_valid_unchecked(i) { - builder.push_unchecked(values, i) - } else { - builder.push_null() - } - } - } - - builder.finish().into() - } -} - -impl ExplodeByOffsets for StringChunked { - fn explode_by_offsets(&self, offsets: &[i64]) -> Series { - unsafe { - self.as_binary() - .explode_by_offsets(offsets) - .cast_unchecked(&DataType::String) - .unwrap() - } - } -} - -impl ExplodeByOffsets for BinaryChunked { - fn explode_by_offsets(&self, offsets: &[i64]) -> Series { - debug_assert_eq!(self.chunks.len(), 1); - let arr = self.downcast_iter().next().unwrap(); - - let cap = get_capacity(offsets); - let mut builder = BinaryChunkedBuilder::new(self.name(), cap); - - let mut start = offsets[0] as usize; - let mut last = start; - for &o in &offsets[1..] { - let o = o as usize; - if o == last { - if start != last { - let vals = arr.slice_typed(start, last - start); - if vals.null_count() == 0 { - builder - .chunk_builder - .extend_trusted_len_values(vals.values_iter()) - } else { - builder.chunk_builder.extend_trusted_len(vals.into_iter()); - } - } - builder.append_null(); - start = o; - } - last = o; - } - let vals = arr.slice_typed(start, last - start); - if vals.null_count() == 0 { - builder - .chunk_builder - .extend_trusted_len_values(vals.values_iter()) - } else { - builder.chunk_builder.extend_trusted_len(vals.into_iter()); - } - builder.finish().into() - } -} - /// Convert Arrow array offsets to indexes of the original list pub(crate) fn offsets_to_indexes(offsets: &[i64], capacity: usize) -> Vec { if offsets.is_empty() { @@ -430,13 +269,17 @@ mod test { #[test] fn test_explode_list() -> PolarsResult<()> { - let mut builder = get_list_builder(&DataType::Int32, 5, 5, "a")?; + let mut builder = get_list_builder(&DataType::Int32, 5, 5, PlSmallStr::from_static("a"))?; builder - .append_series(&Series::new("", &[1, 2, 3, 3])) + .append_series(&Series::new(PlSmallStr::EMPTY, &[1, 2, 3, 3])) + .unwrap(); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[1])) + .unwrap(); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[2])) .unwrap(); - builder.append_series(&Series::new("", &[1])).unwrap(); - builder.append_series(&Series::new("", &[2])).unwrap(); let ca = builder.finish(); assert!(ca._can_fast_explode()); @@ -454,41 +297,19 @@ mod test { Ok(()) } - #[test] - fn test_explode_list_nulls() -> PolarsResult<()> { - let ca = Int32Chunked::from_slice_options("", &[None, Some(1), Some(2)]); - let offsets = &[0, 3, 3]; - let out = ca.explode_by_offsets(offsets); - assert_eq!( - Vec::from(out.i32().unwrap()), - &[None, Some(1), Some(2), None] - ); - - let ca = BooleanChunked::from_slice_options("", &[None, Some(true), Some(false)]); - let out = ca.explode_by_offsets(offsets); - assert_eq!( - Vec::from(out.bool().unwrap()), - &[None, Some(true), Some(false), None] - ); - - let ca = StringChunked::from_slice_options("", &[None, Some("b"), Some("c")]); - let out = ca.explode_by_offsets(offsets); - assert_eq!( - Vec::from(out.str().unwrap()), - &[None, Some("b"), 
Some("c"), None] - ); - Ok(()) - } - #[test] fn test_explode_empty_list_slot() -> PolarsResult<()> { // primitive - let mut builder = get_list_builder(&DataType::Int32, 5, 5, "a")?; - builder.append_series(&Series::new("", &[1i32, 2])).unwrap(); + let mut builder = get_list_builder(&DataType::Int32, 5, 5, PlSmallStr::from_static("a"))?; + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[1i32, 2])) + .unwrap(); builder - .append_series(&Int32Chunked::from_slice("", &[]).into_series()) + .append_series(&Int32Chunked::from_slice(PlSmallStr::EMPTY, &[]).into_series()) + .unwrap(); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[3i32])) .unwrap(); - builder.append_series(&Series::new("", &[3i32])).unwrap(); let ca = builder.finish(); let exploded = ca.explode()?; @@ -498,16 +319,22 @@ mod test { ); // more primitive - let mut builder = get_list_builder(&DataType::Int32, 5, 5, "a")?; - builder.append_series(&Series::new("", &[1i32])).unwrap(); + let mut builder = get_list_builder(&DataType::Int32, 5, 5, PlSmallStr::from_static("a"))?; + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[1i32])) + .unwrap(); builder - .append_series(&Int32Chunked::from_slice("", &[]).into_series()) + .append_series(&Int32Chunked::from_slice(PlSmallStr::EMPTY, &[]).into_series()) .unwrap(); - builder.append_series(&Series::new("", &[2i32])).unwrap(); builder - .append_series(&Int32Chunked::from_slice("", &[]).into_series()) + .append_series(&Series::new(PlSmallStr::EMPTY, &[2i32])) + .unwrap(); + builder + .append_series(&Int32Chunked::from_slice(PlSmallStr::EMPTY, &[]).into_series()) + .unwrap(); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[3, 4i32])) .unwrap(); - builder.append_series(&Series::new("", &[3, 4i32])).unwrap(); let ca = builder.finish(); let exploded = ca.explode()?; @@ -517,26 +344,41 @@ mod test { ); // string - let mut builder = get_list_builder(&DataType::String, 5, 5, "a")?; - builder.append_series(&Series::new("", &["abc"])).unwrap(); + let mut builder = get_list_builder(&DataType::String, 5, 5, PlSmallStr::from_static("a"))?; + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &["abc"])) + .unwrap(); builder .append_series( - &>::from_slice("", &[]) - .into_series(), + &>::from_slice( + PlSmallStr::EMPTY, + &[], + ) + .into_series(), ) .unwrap(); - builder.append_series(&Series::new("", &["de"])).unwrap(); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &["de"])) + .unwrap(); builder .append_series( - &>::from_slice("", &[]) - .into_series(), + &>::from_slice( + PlSmallStr::EMPTY, + &[], + ) + .into_series(), ) .unwrap(); - builder.append_series(&Series::new("", &["fg"])).unwrap(); + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &["fg"])) + .unwrap(); builder .append_series( - &>::from_slice("", &[]) - .into_series(), + &>::from_slice( + PlSmallStr::EMPTY, + &[], + ) + .into_series(), ) .unwrap(); @@ -548,17 +390,21 @@ mod test { ); // boolean - let mut builder = get_list_builder(&DataType::Boolean, 5, 5, "a")?; - builder.append_series(&Series::new("", &[true])).unwrap(); + let mut builder = get_list_builder(&DataType::Boolean, 5, 5, PlSmallStr::from_static("a"))?; + builder + .append_series(&Series::new(PlSmallStr::EMPTY, &[true])) + .unwrap(); + builder + .append_series(&BooleanChunked::from_slice(PlSmallStr::EMPTY, &[]).into_series()) + .unwrap(); builder - .append_series(&BooleanChunked::from_slice("", &[]).into_series()) + .append_series(&Series::new(PlSmallStr::EMPTY, &[false])) .unwrap(); - 
builder.append_series(&Series::new("", &[false])).unwrap(); builder - .append_series(&BooleanChunked::from_slice("", &[]).into_series()) + .append_series(&BooleanChunked::from_slice(PlSmallStr::EMPTY, &[]).into_series()) .unwrap(); builder - .append_series(&Series::new("", &[true, true])) + .append_series(&Series::new(PlSmallStr::EMPTY, &[true, true])) .unwrap(); let ca = builder.finish(); diff --git a/crates/polars-core/src/chunked_array/ops/explode_and_offsets.rs b/crates/polars-core/src/chunked_array/ops/explode_and_offsets.rs index f335e3074665..8b1b87cbdaf8 100644 --- a/crates/polars-core/src/chunked_array/ops/explode_and_offsets.rs +++ b/crates/polars-core/src/chunked_array/ops/explode_and_offsets.rs @@ -3,6 +3,51 @@ use arrow::offset::OffsetsBuffer; use super::*; +impl ListChunked { + fn specialized( + &self, + values: ArrayRef, + offsets: &[i64], + offsets_buf: OffsetsBuffer, + ) -> (Series, OffsetsBuffer) { + // SAFETY: inner_dtype should be correct + let values = unsafe { + Series::from_chunks_and_dtype_unchecked( + self.name().clone(), + vec![values], + &self.inner_dtype().to_physical(), + ) + }; + + use crate::chunked_array::ops::explode::ExplodeByOffsets; + + let mut values = match values.dtype() { + DataType::Boolean => { + let t = values.bool().unwrap(); + ExplodeByOffsets::explode_by_offsets(t, offsets).into_series() + }, + DataType::Null => { + let t = values.null().unwrap(); + ExplodeByOffsets::explode_by_offsets(t, offsets).into_series() + }, + dtype => { + with_match_physical_numeric_polars_type!(dtype, |$T| { + let t: &ChunkedArray<$T> = values.as_ref().as_ref(); + ExplodeByOffsets::explode_by_offsets(t, offsets).into_series() + }) + }, + }; + + // let mut values = values.explode_by_offsets(offsets); + // restore logical type + unsafe { + values = values.cast_unchecked(self.inner_dtype()).unwrap(); + } + + (values, offsets_buf) + } +} + impl ChunkExplode for ListChunked { fn offsets(&self) -> PolarsResult> { let ca = self.rechunk(); @@ -40,7 +85,7 @@ impl ChunkExplode for ListChunked { ( unsafe { Series::from_chunks_and_dtype_unchecked( - self.name(), + self.name().clone(), vec![values], &self.inner_dtype().to_physical(), ) @@ -64,16 +109,36 @@ impl ChunkExplode for ListChunked { panic!("could have fast exploded") } } - if listarr.null_count() == 0 { - // SAFETY: inner_dtype should be correct - let values = unsafe { - Series::from_chunks_and_dtype_unchecked( - self.name(), - vec![values], - &self.inner_dtype().to_physical(), - ) - }; - (values.explode_by_offsets(offsets), offsets_buf) + let (indices, new_offsets) = if listarr.null_count() == 0 { + // SPECIALIZED path. 
+ let inner_phys = self.inner_dtype().to_physical(); + if inner_phys.is_numeric() || inner_phys.is_null() || inner_phys.is_bool() { + return Ok(self.specialized(values, offsets, offsets_buf)); + } + // Use gather + let mut indices = + MutablePrimitiveArray::::with_capacity(*offsets_buf.last() as usize); + let mut new_offsets = Vec::with_capacity(listarr.len() + 1); + let mut current_offset = 0i64; + let mut iter = offsets.iter(); + if let Some(mut previous) = iter.next().copied() { + new_offsets.push(current_offset); + iter.for_each(|&offset| { + let len = offset - previous; + let start = previous as IdxSize; + let end = offset as IdxSize; + + if len == 0 { + indices.push_null(); + } else { + indices.extend_trusted_len_values(start..end); + } + current_offset += len; + previous = offset; + new_offsets.push(current_offset); + }) + } + (indices, new_offsets) } else { // we have already ensure that validity is not none. let validity = listarr.validity().unwrap(); @@ -105,20 +170,22 @@ impl ChunkExplode for ListChunked { new_offsets.push(current_offset); }) } - // SAFETY: the indices we generate are in bounds - let chunk = unsafe { take_unchecked(values.as_ref(), &indices.into()) }; - // SAFETY: inner_dtype should be correct - let s = unsafe { - Series::from_chunks_and_dtype_unchecked( - self.name(), - vec![chunk], - &self.inner_dtype().to_physical(), - ) - }; - // SAFETY: monotonically increasing - let new_offsets = unsafe { OffsetsBuffer::new_unchecked(new_offsets.into()) }; - (s, new_offsets) - } + (indices, new_offsets) + }; + + // SAFETY: the indices we generate are in bounds + let chunk = unsafe { take_unchecked(values.as_ref(), &indices.into()) }; + // SAFETY: inner_dtype should be correct + let s = unsafe { + Series::from_chunks_and_dtype_unchecked( + self.name().clone(), + vec![chunk], + &self.inner_dtype().to_physical(), + ) + }; + // SAFETY: monotonically increasing + let new_offsets = unsafe { OffsetsBuffer::new_unchecked(new_offsets.into()) }; + (s, new_offsets) }; debug_assert_eq!(s.name(), self.name()); // restore logical type @@ -177,7 +244,7 @@ impl ChunkExplode for ArrayChunked { let arr = ca.downcast_iter().next().unwrap(); // fast-path for non-null array. if arr.null_count() == 0 { - let s = Series::try_from((self.name(), arr.values().clone())) + let s = Series::try_from((self.name().clone(), arr.values().clone())) .unwrap() .cast(ca.inner_dtype())?; let width = self.width() as i64; @@ -224,7 +291,11 @@ impl ChunkExplode for ArrayChunked { Ok(( // SAFETY: inner_dtype should be correct unsafe { - Series::from_chunks_and_dtype_unchecked(ca.name(), vec![chunk], ca.inner_dtype()) + Series::from_chunks_and_dtype_unchecked( + ca.name().clone(), + vec![chunk], + ca.inner_dtype(), + ) }, offsets, )) diff --git a/crates/polars-core/src/chunked_array/ops/extend.rs b/crates/polars-core/src/chunked_array/ops/extend.rs index ccae7d6e18bd..9489c425d3ff 100644 --- a/crates/polars-core/src/chunked_array/ops/extend.rs +++ b/crates/polars-core/src/chunked_array/ops/extend.rs @@ -36,13 +36,13 @@ where /// Prefer `append` over `extend` when you want to append many times before doing a query. For instance /// when you read in multiple files and when to store them in a single `DataFrame`. /// In the latter case finish the sequence of `append` operations with a [`rechunk`](Self::rechunk). 
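Note: `extend` here changes to return `PolarsResult<()>` so it can propagate errors from the inner `append` call, and all callers (including the tests below) switch to `?`/`unwrap`. A minimal illustration under the new signature:

```rust
use polars_core::prelude::*;

fn concat_in_place() -> PolarsResult<()> {
    let mut a = Int32Chunked::from_vec(PlSmallStr::from_static("a"), vec![1, 2, 3]);
    let b = Int32Chunked::from_vec(PlSmallStr::from_static("a"), vec![4, 5, 6]);
    // `extend` mutates the existing chunk where possible and is now fallible.
    a.extend(&b)?;
    assert_eq!(a.len(), 6);
    Ok(())
}

fn main() {
    concat_in_place().unwrap();
}
```

As the doc comment above notes, repeated `append` followed by a final `rechunk` remains preferable when appending many times before querying.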
- pub fn extend(&mut self, other: &Self) { + pub fn extend(&mut self, other: &Self) -> PolarsResult<()> { update_sorted_flag_before_append::(self, other); // all to a single chunk if self.chunks.len() > 1 { - self.append(other); + self.append(other)?; *self = self.rechunk(); - return; + return Ok(()); } // Depending on the state of the underlying arrow array we // might be able to get a `MutablePrimitiveArray` @@ -84,12 +84,13 @@ where } } self.compute_len(); + Ok(()) } } #[doc(hidden)] impl StringChunked { - pub fn extend(&mut self, other: &Self) { + pub fn extend(&mut self, other: &Self) -> PolarsResult<()> { self.set_sorted_flag(IsSorted::Not); self.append(other) } @@ -97,7 +98,7 @@ impl StringChunked { #[doc(hidden)] impl BinaryChunked { - pub fn extend(&mut self, other: &Self) { + pub fn extend(&mut self, other: &Self) -> PolarsResult<()> { self.set_sorted_flag(IsSorted::Not); self.append(other) } @@ -105,7 +106,7 @@ impl BinaryChunked { #[doc(hidden)] impl BinaryOffsetChunked { - pub fn extend(&mut self, other: &Self) { + pub fn extend(&mut self, other: &Self) -> PolarsResult<()> { self.set_sorted_flag(IsSorted::Not); self.append(other) } @@ -113,13 +114,13 @@ impl BinaryOffsetChunked { #[doc(hidden)] impl BooleanChunked { - pub fn extend(&mut self, other: &Self) { + pub fn extend(&mut self, other: &Self) -> PolarsResult<()> { update_sorted_flag_before_append::(self, other); // make sure that we are a single chunk already if self.chunks.len() > 1 { - self.append(other); + self.append(other)?; *self = self.rechunk(); - return; + return Ok(()); } let arr = self.downcast_iter().next().unwrap(); @@ -148,6 +149,7 @@ impl BooleanChunked { } self.compute_len(); self.set_sorted_flag(IsSorted::Not); + Ok(()) } } @@ -189,48 +191,54 @@ mod test { #[test] #[allow(clippy::redundant_clone)] - fn test_extend_primitive() { + fn test_extend_primitive() -> PolarsResult<()> { // create a vec with overcapacity, so that we do not trigger a realloc // this allows us to test if the mutation was successful let mut values = Vec::with_capacity(32); values.extend_from_slice(&[1, 2, 3]); - let mut ca = Int32Chunked::from_vec("a", values); + let mut ca = Int32Chunked::from_vec(PlSmallStr::from_static("a"), values); let location = ca.cont_slice().unwrap().as_ptr() as usize; - let to_append = Int32Chunked::new("a", &[4, 5, 6]); + let to_append = Int32Chunked::new(PlSmallStr::from_static("a"), &[4, 5, 6]); - ca.extend(&to_append); + ca.extend(&to_append)?; let location2 = ca.cont_slice().unwrap().as_ptr() as usize; assert_eq!(location, location2); assert_eq!(ca.cont_slice().unwrap(), [1, 2, 3, 4, 5, 6]); // now check if it succeeds if we cannot do this with a mutable. 
let _temp = ca.chunks.clone(); - ca.extend(&to_append); + ca.extend(&to_append)?; let location2 = ca.cont_slice().unwrap().as_ptr() as usize; assert_ne!(location, location2); assert_eq!(ca.cont_slice().unwrap(), [1, 2, 3, 4, 5, 6, 4, 5, 6]); + + Ok(()) } #[test] - fn test_extend_string() { - let mut ca = StringChunked::new("a", &["a", "b", "c"]); - let to_append = StringChunked::new("a", &["a", "b", "e"]); + fn test_extend_string() -> PolarsResult<()> { + let mut ca = StringChunked::new(PlSmallStr::from_static("a"), &["a", "b", "c"]); + let to_append = StringChunked::new(PlSmallStr::from_static("a"), &["a", "b", "e"]); - ca.extend(&to_append); + ca.extend(&to_append)?; assert_eq!(ca.len(), 6); let vals = ca.into_no_null_iter().collect::>(); - assert_eq!(vals, ["a", "b", "c", "a", "b", "e"]) + assert_eq!(vals, ["a", "b", "c", "a", "b", "e"]); + + Ok(()) } #[test] - fn test_extend_bool() { - let mut ca = BooleanChunked::new("a", [true, false]); - let to_append = BooleanChunked::new("a", &[false, false]); + fn test_extend_bool() -> PolarsResult<()> { + let mut ca = BooleanChunked::new(PlSmallStr::from_static("a"), [true, false]); + let to_append = BooleanChunked::new(PlSmallStr::from_static("a"), &[false, false]); - ca.extend(&to_append); + ca.extend(&to_append)?; assert_eq!(ca.len(), 4); let vals = ca.into_no_null_iter().collect::>(); assert_eq!(vals, [true, false, false, false]); + + Ok(()) } } diff --git a/crates/polars-core/src/chunked_array/ops/fill_null.rs b/crates/polars-core/src/chunked_array/ops/fill_null.rs index 373311eda9b4..7aa348d5e440 100644 --- a/crates/polars-core/src/chunked_array/ops/fill_null.rs +++ b/crates/polars-core/src/chunked_array/ops/fill_null.rs @@ -2,7 +2,7 @@ use arrow::bitmap::{Bitmap, MutableBitmap}; use arrow::legacy::kernels::set::set_at_nulls; use bytemuck::Zeroable; use num_traits::{Bounded, NumCast, One, Zero}; -use polars_utils::iter::EnumerateIdxTrait; +use polars_utils::itertools::Itertools; use crate::prelude::*; @@ -30,7 +30,7 @@ impl Series { /// ```rust /// # use polars_core::prelude::*; /// fn example() -> PolarsResult<()> { - /// let s = Series::new("some_missing", &[Some(1), None, Some(2)]); + /// let s = Series::new("some_missing".into(), &[Some(1), None, Some(2)]); /// /// let filled = s.fill_null(FillNullStrategy::Forward(None))?; /// assert_eq!(Vec::from(filled.i32()?), &[Some(1), Some(1), Some(2)]); @@ -219,7 +219,7 @@ where FillNullStrategy::Forward(_) => unreachable!(), FillNullStrategy::Backward(_) => unreachable!(), }; - out.rename(ca.name()); + out.rename(ca.name().clone()); Ok(out) } diff --git a/crates/polars-core/src/chunked_array/ops/filter.rs b/crates/polars-core/src/chunked_array/ops/filter.rs index 2fb493f9e9c3..a927f3c6cd99 100644 --- a/crates/polars-core/src/chunked_array/ops/filter.rs +++ b/crates/polars-core/src/chunked_array/ops/filter.rs @@ -190,12 +190,12 @@ where if filter.len() == 1 { return match filter.get(0) { Some(true) => Ok(self.clone()), - _ => Ok(ObjectChunked::new_empty(self.name())), + _ => Ok(ObjectChunked::new_empty(self.name().clone())), }; } check_filter_len!(self, filter); let chunks = self.downcast_iter().collect::>(); - let mut builder = ObjectChunkedBuilder::::new(self.name(), self.len()); + let mut builder = ObjectChunkedBuilder::::new(self.name().clone(), self.len()); for (idx, mask) in filter.into_iter().enumerate() { if mask.unwrap_or(false) { let (chunk_idx, idx) = self.index_to_chunked_index(idx); diff --git a/crates/polars-core/src/chunked_array/ops/full.rs 
b/crates/polars-core/src/chunked_array/ops/full.rs index fce46375666c..3f797d588e47 100644 --- a/crates/polars-core/src/chunked_array/ops/full.rs +++ b/crates/polars-core/src/chunked_array/ops/full.rs @@ -1,4 +1,4 @@ -use arrow::bitmap::MutableBitmap; +use arrow::bitmap::{Bitmap, MutableBitmap}; use crate::chunked_array::builder::get_list_builder; use crate::prelude::*; @@ -8,7 +8,7 @@ impl ChunkFull for ChunkedArray where T: PolarsNumericType, { - fn full(name: &str, value: T::Native, length: usize) -> Self { + fn full(name: PlSmallStr, value: T::Native, length: usize) -> Self { let data = vec![value; length]; let mut out = ChunkedArray::from_vec(name, data); out.set_sorted_flag(IsSorted::Ascending); @@ -20,13 +20,13 @@ impl ChunkFullNull for ChunkedArray where T: PolarsNumericType, { - fn full_null(name: &str, length: usize) -> Self { + fn full_null(name: PlSmallStr, length: usize) -> Self { let arr = PrimitiveArray::new_null(T::get_dtype().to_arrow(CompatLevel::newest()), length); ChunkedArray::with_chunk(name, arr) } } impl ChunkFull for BooleanChunked { - fn full(name: &str, value: bool, length: usize) -> Self { + fn full(name: PlSmallStr, value: bool, length: usize) -> Self { let mut bits = MutableBitmap::with_capacity(length); bits.extend_constant(length, value); let arr = BooleanArray::from_data_default(bits.into(), None); @@ -37,14 +37,14 @@ impl ChunkFull for BooleanChunked { } impl ChunkFullNull for BooleanChunked { - fn full_null(name: &str, length: usize) -> Self { + fn full_null(name: PlSmallStr, length: usize) -> Self { let arr = BooleanArray::new_null(ArrowDataType::Boolean, length); ChunkedArray::with_chunk(name, arr) } } impl<'a> ChunkFull<&'a str> for StringChunked { - fn full(name: &str, value: &'a str, length: usize) -> Self { + fn full(name: PlSmallStr, value: &'a str, length: usize) -> Self { let mut builder = StringChunkedBuilder::new(name, length); builder.chunk_builder.extend_constant(length, Some(value)); let mut out = builder.finish(); @@ -54,14 +54,14 @@ impl<'a> ChunkFull<&'a str> for StringChunked { } impl ChunkFullNull for StringChunked { - fn full_null(name: &str, length: usize) -> Self { + fn full_null(name: PlSmallStr, length: usize) -> Self { let arr = Utf8ViewArray::new_null(DataType::String.to_arrow(CompatLevel::newest()), length); ChunkedArray::with_chunk(name, arr) } } impl<'a> ChunkFull<&'a [u8]> for BinaryChunked { - fn full(name: &str, value: &'a [u8], length: usize) -> Self { + fn full(name: PlSmallStr, value: &'a [u8], length: usize) -> Self { let mut builder = BinaryChunkedBuilder::new(name, length); builder.chunk_builder.extend_constant(length, Some(value)); let mut out = builder.finish(); @@ -71,7 +71,7 @@ impl<'a> ChunkFull<&'a [u8]> for BinaryChunked { } impl ChunkFullNull for BinaryChunked { - fn full_null(name: &str, length: usize) -> Self { + fn full_null(name: PlSmallStr, length: usize) -> Self { let arr = BinaryViewArray::new_null(DataType::Binary.to_arrow(CompatLevel::newest()), length); ChunkedArray::with_chunk(name, arr) @@ -79,7 +79,7 @@ impl ChunkFullNull for BinaryChunked { } impl<'a> ChunkFull<&'a [u8]> for BinaryOffsetChunked { - fn full(name: &str, value: &'a [u8], length: usize) -> Self { + fn full(name: PlSmallStr, value: &'a [u8], length: usize) -> Self { let mut mutable = MutableBinaryArray::with_capacities(length, length * value.len()); mutable.extend_values(std::iter::repeat(value).take(length)); let arr: BinaryArray = mutable.into(); @@ -90,7 +90,7 @@ impl<'a> ChunkFull<&'a [u8]> for BinaryOffsetChunked { } impl 
ChunkFullNull for BinaryOffsetChunked { - fn full_null(name: &str, length: usize) -> Self { + fn full_null(name: PlSmallStr, length: usize) -> Self { let arr = BinaryArray::::new_null( DataType::BinaryOffset.to_arrow(CompatLevel::newest()), length, @@ -100,7 +100,7 @@ impl ChunkFullNull for BinaryOffsetChunked { } impl ChunkFull<&Series> for ListChunked { - fn full(name: &str, value: &Series, length: usize) -> ListChunked { + fn full(name: PlSmallStr, value: &Series, length: usize) -> ListChunked { let mut builder = get_list_builder(value.dtype(), value.len() * length, length, name).unwrap(); for _ in 0..length { @@ -111,7 +111,7 @@ impl ChunkFull<&Series> for ListChunked { } impl ChunkFullNull for ListChunked { - fn full_null(name: &str, length: usize) -> ListChunked { + fn full_null(name: PlSmallStr, length: usize) -> ListChunked { ListChunked::full_null_with_dtype(name, length, &DataType::Null) } } @@ -119,7 +119,7 @@ impl ChunkFullNull for ListChunked { #[cfg(feature = "dtype-array")] impl ArrayChunked { pub fn full_null_with_dtype( - name: &str, + name: PlSmallStr, length: usize, inner_dtype: &DataType, width: usize, @@ -127,7 +127,7 @@ impl ArrayChunked { let arr = FixedSizeListArray::new_null( ArrowDataType::FixedSizeList( Box::new(ArrowField::new( - "item", + PlSmallStr::from_static("item"), inner_dtype.to_arrow(CompatLevel::newest()), true, )), @@ -141,12 +141,12 @@ impl ArrayChunked { #[cfg(feature = "dtype-array")] impl ChunkFull<&Series> for ArrayChunked { - fn full(name: &str, value: &Series, length: usize) -> ArrayChunked { + fn full(name: PlSmallStr, value: &Series, length: usize) -> ArrayChunked { let width = value.len(); let dtype = value.dtype(); let arrow_dtype = ArrowDataType::FixedSizeList( Box::new(ArrowField::new( - "item", + PlSmallStr::from_static("item"), dtype.to_arrow(CompatLevel::newest()), true, )), @@ -160,16 +160,20 @@ impl ChunkFull<&Series> for ArrayChunked { #[cfg(feature = "dtype-array")] impl ChunkFullNull for ArrayChunked { - fn full_null(name: &str, length: usize) -> ArrayChunked { + fn full_null(name: PlSmallStr, length: usize) -> ArrayChunked { ArrayChunked::full_null_with_dtype(name, length, &DataType::Null, 0) } } impl ListChunked { - pub fn full_null_with_dtype(name: &str, length: usize, inner_dtype: &DataType) -> ListChunked { + pub fn full_null_with_dtype( + name: PlSmallStr, + length: usize, + inner_dtype: &DataType, + ) -> ListChunked { let arr: ListArray = ListArray::new_null( ArrowDataType::LargeList(Box::new(ArrowField::new( - "item", + PlSmallStr::from_static("item"), inner_dtype.to_physical().to_arrow(CompatLevel::newest()), true, ))), @@ -187,15 +191,17 @@ impl ListChunked { } #[cfg(feature = "dtype-struct")] impl ChunkFullNull for StructChunked { - fn full_null(name: &str, length: usize) -> StructChunked { - let s = vec![Series::new_null("", length)]; - StructChunked::from_series(name, &s).unwrap() + fn full_null(name: PlSmallStr, length: usize) -> StructChunked { + let s = vec![Series::new_null(PlSmallStr::EMPTY, length)]; + StructChunked::from_series(name, &s) + .unwrap() + .with_outer_validity(Some(Bitmap::new_zeroed(length))) } } #[cfg(feature = "object")] impl ChunkFull for ObjectChunked { - fn full(name: &str, value: T, length: usize) -> Self + fn full(name: PlSmallStr, value: T, length: usize) -> Self where Self: Sized, { @@ -207,7 +213,7 @@ impl ChunkFull for ObjectChunked { #[cfg(feature = "object")] impl ChunkFullNull for ObjectChunked { - fn full_null(name: &str, length: usize) -> ObjectChunked { + fn full_null(name: 
PlSmallStr, length: usize) -> ObjectChunked { let mut ca: Self = (0..length).map(|_| None).collect(); ca.rename(name); ca diff --git a/crates/polars-core/src/chunked_array/ops/gather.rs b/crates/polars-core/src/chunked_array/ops/gather.rs index 00c93053cc1e..cb24305f75f6 100644 --- a/crates/polars-core/src/chunked_array/ops/gather.rs +++ b/crates/polars-core/src/chunked_array/ops/gather.rs @@ -244,7 +244,7 @@ impl ChunkTakeUnchecked for BinaryChunked { .map(|arr| take_unchecked(arr.as_ref(), indices_arr)) .collect::>(); - let mut out = ChunkedArray::from_chunks(self.name(), chunks); + let mut out = ChunkedArray::from_chunks(self.name().clone(), chunks); let sorted_flag = _update_gather_sorted_flag(self.is_sorted_flag(), indices.is_sorted_flag()); @@ -264,7 +264,7 @@ impl ChunkTakeUnchecked for StringChunked { impl + ?Sized> ChunkTakeUnchecked for BinaryChunked { /// Gather values from ChunkedArray by index. unsafe fn take_unchecked(&self, indices: &I) -> Self { - let indices = IdxCa::mmap_slice("", indices.as_ref()); + let indices = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref()); self.take_unchecked(&indices) } } @@ -296,7 +296,7 @@ impl ChunkTakeUnchecked for StructChunked { #[cfg(feature = "dtype-struct")] impl + ?Sized> ChunkTakeUnchecked for StructChunked { unsafe fn take_unchecked(&self, indices: &I) -> Self { - let idx = IdxCa::mmap_slice("", indices.as_ref()); + let idx = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref()); self.take_unchecked(&idx) } } @@ -307,7 +307,7 @@ impl IdxCa { let idx = bytemuck::cast_slice::<_, IdxSize>(idx); let arr = unsafe { arrow::ffi::mmap::slice(idx) }; let arr = arr.with_validity_typed(Some(validity)); - let ca = IdxCa::with_chunk("", arr); + let ca = IdxCa::with_chunk(PlSmallStr::EMPTY, arr); f(&ca) } diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index c3d030447794..b252d23814eb 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -121,6 +121,7 @@ pub trait ChunkTakeUnchecked { } /// Create a `ChunkedArray` with new values by index or by boolean mask. +/// /// Note that these operations clone data. This is however the only way we can modify at mask or /// index level as the underlying Arrow arrays are immutable. 
pub trait ChunkSet<'a, A, B> { @@ -130,7 +131,7 @@ pub trait ChunkSet<'a, A, B> { /// /// ```rust /// # use polars_core::prelude::*; - /// let ca = UInt32Chunked::new("a", &[1, 2, 3]); + /// let ca = UInt32Chunked::new("a".into(), &[1, 2, 3]); /// let new = ca.scatter_single(vec![0, 1], Some(10)).unwrap(); /// /// assert_eq!(Vec::from(&new), &[Some(10), Some(10), Some(3)]); @@ -149,7 +150,7 @@ pub trait ChunkSet<'a, A, B> { /// /// ```rust /// # use polars_core::prelude::*; - /// let ca = Int32Chunked::new("a", &[1, 2, 3]); + /// let ca = Int32Chunked::new("a".into(), &[1, 2, 3]); /// let new = ca.scatter_with(vec![0, 1], |opt_v| opt_v.map(|v| v - 5)).unwrap(); /// /// assert_eq!(Vec::from(&new), &[Some(-4), Some(-3), Some(3)]); @@ -168,8 +169,8 @@ pub trait ChunkSet<'a, A, B> { /// /// ```rust /// # use polars_core::prelude::*; - /// let ca = Int32Chunked::new("a", &[1, 2, 3]); - /// let mask = BooleanChunked::new("mask", &[false, true, false]); + /// let ca = Int32Chunked::new("a".into(), &[1, 2, 3]); + /// let mask = BooleanChunked::new("mask".into(), &[false, true, false]); /// let new = ca.set(&mask, Some(5)).unwrap(); /// assert_eq!(Vec::from(&new), &[Some(1), Some(5), Some(3)]); /// ``` @@ -181,20 +182,19 @@ pub trait ChunkSet<'a, A, B> { /// Cast `ChunkedArray` to `ChunkedArray` pub trait ChunkCast { /// Cast a [`ChunkedArray`] to [`DataType`] - fn cast(&self, data_type: &DataType) -> PolarsResult { - self.cast_with_options(data_type, CastOptions::NonStrict) + fn cast(&self, dtype: &DataType) -> PolarsResult { + self.cast_with_options(dtype, CastOptions::NonStrict) } /// Cast a [`ChunkedArray`] to [`DataType`] - fn cast_with_options(&self, data_type: &DataType, options: CastOptions) - -> PolarsResult; + fn cast_with_options(&self, dtype: &DataType, options: CastOptions) -> PolarsResult; /// Does not check if the cast is a valid one and may over/underflow /// /// # Safety /// - This doesn't do utf8 validation checking when casting from binary /// - This doesn't do categorical bound checking when casting from UInt32 - unsafe fn cast_unchecked(&self, data_type: &DataType) -> PolarsResult; + unsafe fn cast_unchecked(&self, dtype: &DataType) -> PolarsResult; } /// Fastest way to do elementwise operations on a [`ChunkedArray`] when the operation is cheaper than @@ -242,6 +242,8 @@ pub trait ChunkAgg { None } + fn _sum_as_f64(&self) -> f64; + fn min(&self) -> Option { None } @@ -424,13 +426,13 @@ pub trait ChunkFillNullValue { /// Fill a ChunkedArray with one value. pub trait ChunkFull { /// Create a ChunkedArray with a single value. - fn full(name: &str, value: T, length: usize) -> Self + fn full(name: PlSmallStr, value: T, length: usize) -> Self where Self: Sized; } pub trait ChunkFullNull { - fn full_null(_name: &str, _length: usize) -> Self + fn full_null(_name: PlSmallStr, _length: usize) -> Self where Self: Sized; } @@ -447,8 +449,8 @@ pub trait ChunkFilter { /// /// ```rust /// # use polars_core::prelude::*; - /// let array = Int32Chunked::new("array", &[1, 2, 3]); - /// let mask = BooleanChunked::new("mask", &[true, false, true]); + /// let array = Int32Chunked::new("array".into(), &[1, 2, 3]); + /// let mask = BooleanChunked::new("mask".into(), &[true, false, true]); /// /// let filtered = array.filter(&mask).unwrap(); /// assert_eq!(Vec::from(&filtered), [Some(1), Some(3)]) @@ -461,7 +463,7 @@ pub trait ChunkFilter { /// Create a new ChunkedArray filled with values at that index. pub trait ChunkExpandAtIndex { /// Create a new ChunkedArray filled with values at that index. 
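Note: the `ChunkExpandAtIndex` declaration just below is corrected to take `(index, length)` in that order, matching the existing implementations and the `impl_chunk_expand!` macro, which reads the value at `index` and repeats it `length` times. A plain-std sketch of those semantics, using a hypothetical free function for illustration only:

```rust
// Hypothetical stand-in for the new_from_index semantics: repeat the value
// found at `index` exactly `length` times.
fn new_from_index<T: Clone>(values: &[Option<T>], index: usize, length: usize) -> Vec<Option<T>> {
    std::iter::repeat(values[index].clone()).take(length).collect()
}

fn main() {
    let vals: Vec<Option<i32>> = vec![Some(10), None, Some(30)];
    assert_eq!(new_from_index(&vals, 0, 3), vec![Some(10), Some(10), Some(10)]);
    assert_eq!(new_from_index(&vals, 1, 2), vec![None, None]);
}
```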
- fn new_from_index(&self, length: usize, index: usize) -> ChunkedArray; + fn new_from_index(&self, index: usize, length: usize) -> ChunkedArray; } macro_rules! impl_chunk_expand { @@ -471,8 +473,8 @@ macro_rules! impl_chunk_expand { } let opt_val = $self.get($index); match opt_val { - Some(val) => ChunkedArray::full($self.name(), val, $length), - None => ChunkedArray::full_null($self.name(), $length), + Some(val) => ChunkedArray::full($self.name().clone(), val, $length), + None => ChunkedArray::full_null($self.name().clone(), $length), } }}; } @@ -525,34 +527,36 @@ impl ChunkExpandAtIndex for ListChunked { let opt_val = self.get_as_series(index); match opt_val { Some(val) => { - let mut ca = ListChunked::full(self.name(), &val, length); + let mut ca = ListChunked::full(self.name().clone(), &val, length); unsafe { ca.to_logical(self.inner_dtype().clone()) }; ca }, - None => ListChunked::full_null_with_dtype(self.name(), length, self.inner_dtype()), + None => { + ListChunked::full_null_with_dtype(self.name().clone(), length, self.inner_dtype()) + }, } } } #[cfg(feature = "dtype-struct")] impl ChunkExpandAtIndex for StructChunked { - fn new_from_index(&self, length: usize, index: usize) -> ChunkedArray { + fn new_from_index(&self, index: usize, length: usize) -> ChunkedArray { let (chunk_idx, idx) = self.index_to_chunked_index(index); let chunk = self.downcast_chunks().get(chunk_idx).unwrap(); let chunk = if chunk.is_null(idx) { - new_null_array(chunk.data_type().clone(), length) + new_null_array(chunk.dtype().clone(), length) } else { let values = chunk .values() .iter() .map(|arr| { - let s = Series::try_from(("", arr.clone())).unwrap(); + let s = Series::try_from((PlSmallStr::EMPTY, arr.clone())).unwrap(); let s = s.new_from_index(idx, length); s.chunks()[0].clone() }) .collect::>(); - StructArray::new(chunk.data_type().clone(), values, None).boxed() + StructArray::new(chunk.dtype().clone(), values, None).boxed() }; // SAFETY: chunks are from self. @@ -566,12 +570,12 @@ impl ChunkExpandAtIndex for ArrayChunked { let opt_val = self.get_as_series(index); match opt_val { Some(val) => { - let mut ca = ArrayChunked::full(self.name(), &val, length); + let mut ca = ArrayChunked::full(self.name().clone(), &val, length); unsafe { ca.to_logical(self.inner_dtype().clone()) }; ca }, None => ArrayChunked::full_null_with_dtype( - self.name(), + self.name().clone(), length, self.inner_dtype(), self.width(), @@ -585,8 +589,8 @@ impl ChunkExpandAtIndex> for ObjectChunked { fn new_from_index(&self, index: usize, length: usize) -> ObjectChunked { let opt_val = self.get(index); match opt_val { - Some(val) => ObjectChunked::::full(self.name(), val.clone(), length), - None => ObjectChunked::::full_null(self.name(), length), + Some(val) => ObjectChunked::::full(self.name().clone(), val.clone(), length), + None => ObjectChunked::::full_null(self.name().clone(), length), } } } diff --git a/crates/polars-core/src/chunked_array/ops/nulls.rs b/crates/polars-core/src/chunked_array/ops/nulls.rs index c0ba435c4a51..1d1640055a72 100644 --- a/crates/polars-core/src/chunked_array/ops/nulls.rs +++ b/crates/polars-core/src/chunked_array/ops/nulls.rs @@ -7,19 +7,19 @@ impl ChunkedArray { /// Get a mask of the null values. 
pub fn is_null(&self) -> BooleanChunked { if !self.has_nulls() { - return BooleanChunked::full(self.name(), false, self.len()); + return BooleanChunked::full(self.name().clone(), false, self.len()); } // dispatch to non-generic function - is_null(self.name(), &self.chunks) + is_null(self.name().clone(), &self.chunks) } /// Get a mask of the valid values. pub fn is_not_null(&self) -> BooleanChunked { if self.null_count() == 0 { - return BooleanChunked::full(self.name(), true, self.len()); + return BooleanChunked::full(self.name().clone(), true, self.len()); } // dispatch to non-generic function - is_not_null(self.name(), &self.chunks) + is_not_null(self.name().clone(), &self.chunks) } pub(crate) fn coalesce_nulls(&self, other: &[ArrayRef]) -> Self { @@ -30,7 +30,7 @@ impl ChunkedArray { } } -pub fn is_not_null(name: &str, chunks: &[ArrayRef]) -> BooleanChunked { +pub fn is_not_null(name: PlSmallStr, chunks: &[ArrayRef]) -> BooleanChunked { let chunks = chunks.iter().map(|arr| { let bitmap = arr .validity() @@ -41,7 +41,7 @@ pub fn is_not_null(name: &str, chunks: &[ArrayRef]) -> BooleanChunked { BooleanChunked::from_chunk_iter(name, chunks) } -pub fn is_null(name: &str, chunks: &[ArrayRef]) -> BooleanChunked { +pub fn is_null(name: PlSmallStr, chunks: &[ArrayRef]) -> BooleanChunked { let chunks = chunks.iter().map(|arr| { let bitmap = arr .validity() @@ -52,7 +52,7 @@ pub fn is_null(name: &str, chunks: &[ArrayRef]) -> BooleanChunked { BooleanChunked::from_chunk_iter(name, chunks) } -pub fn replace_non_null(name: &str, chunks: &[ArrayRef], default: bool) -> BooleanChunked { +pub fn replace_non_null(name: PlSmallStr, chunks: &[ArrayRef], default: bool) -> BooleanChunked { BooleanChunked::from_chunk_iter( name, chunks.iter().map(|el| { diff --git a/crates/polars-core/src/chunked_array/ops/reverse.rs b/crates/polars-core/src/chunked_array/ops/reverse.rs index 9d3b0938f390..0436b2264c5c 100644 --- a/crates/polars-core/src/chunked_array/ops/reverse.rs +++ b/crates/polars-core/src/chunked_array/ops/reverse.rs @@ -15,7 +15,7 @@ where } else { self.into_iter().rev().collect_trusted() }; - out.rename(self.name()); + out.rename(self.name().clone()); match self.is_sorted_flag() { IsSorted::Ascending => out.set_sorted_flag(IsSorted::Descending), @@ -32,7 +32,7 @@ macro_rules! 
impl_reverse { impl ChunkReverse for $ca_type { fn reverse(&self) -> Self { let mut ca: Self = self.into_iter().rev().collect_trusted(); - ca.rename(self.name()); + ca.rename(self.name().clone()); ca } } @@ -51,7 +51,7 @@ impl ChunkReverse for BinaryChunked { unsafe { let arr = BinaryViewArray::new_unchecked( - arr.data_type().clone(), + arr.dtype().clone(), views.into(), arr.data_buffers().clone(), arr.validity().map(|bitmap| bitmap.iter().rev().collect()), @@ -60,13 +60,16 @@ impl ChunkReverse for BinaryChunked { ) .boxed(); BinaryChunked::from_chunks_and_dtype_unchecked( - self.name(), + self.name().clone(), vec![arr], self.dtype().clone(), ) } } else { - let ca = IdxCa::from_vec("", (0..self.len() as IdxSize).rev().collect()); + let ca = IdxCa::from_vec( + PlSmallStr::EMPTY, + (0..self.len() as IdxSize).rev().collect(), + ); unsafe { self.take_unchecked(&ca) } } } @@ -89,7 +92,7 @@ impl ChunkReverse for ArrayChunked { let values = arr.values().as_ref(); let mut builder = - get_fixed_size_list_builder(ca.inner_dtype(), ca.len(), ca.width(), ca.name()) + get_fixed_size_list_builder(ca.inner_dtype(), ca.len(), ca.width(), ca.name().clone()) .expect("not yet supported"); // SAFETY, we are within bounds @@ -117,6 +120,12 @@ impl ChunkReverse for ArrayChunked { impl ChunkReverse for ObjectChunked { fn reverse(&self) -> Self { // SAFETY: we know we don't go out of bounds. - unsafe { self.take_unchecked(&(0..self.len() as IdxSize).rev().collect_ca("")) } + unsafe { + self.take_unchecked( + &(0..self.len() as IdxSize) + .rev() + .collect_ca(PlSmallStr::EMPTY), + ) + } } } diff --git a/crates/polars-core/src/chunked_array/ops/rolling_window.rs b/crates/polars-core/src/chunked_array/ops/rolling_window.rs index c5898edb4df1..fd6e6ff491c9 100644 --- a/crates/polars-core/src/chunked_array/ops/rolling_window.rs +++ b/crates/polars-core/src/chunked_array/ops/rolling_window.rs @@ -106,14 +106,15 @@ mod inner_mod { let len = self.len(); let arr = ca.downcast_iter().next().unwrap(); - let mut ca = ChunkedArray::::from_slice("", &[T::Native::zero()]); + let mut ca = ChunkedArray::::from_slice(PlSmallStr::EMPTY, &[T::Native::zero()]); let ptr = ca.chunks[0].as_mut() as *mut dyn Array as *mut PrimitiveArray; let mut series_container = ca.into_series(); - let mut builder = PrimitiveChunkedBuilder::::new(self.name(), self.len()); + let mut builder = PrimitiveChunkedBuilder::::new(self.name().clone(), self.len()); if let Some(weights) = options.weights { - let weights_series = Float64Chunked::new("weights", &weights).into_series(); + let weights_series = + Float64Chunked::new(PlSmallStr::from_static("weights"), &weights).into_series(); let weights_series = weights_series.cast(self.dtype()).unwrap(); @@ -221,7 +222,7 @@ mod inner_mod { F: FnMut(&mut ChunkedArray) -> Option, { if window_size > self.len() { - return Ok(Self::full_null(self.name(), self.len())); + return Ok(Self::full_null(self.name().clone(), self.len())); } let ca = self.rechunk(); let arr = ca.downcast_iter().next().unwrap(); @@ -229,7 +230,8 @@ mod inner_mod { // We create a temporary dummy ChunkedArray. This will be a // container where we swap the window contents every iteration doing // so will save a lot of heap allocations. 
- let mut heap_container = ChunkedArray::::from_slice("", &[T::Native::zero()]); + let mut heap_container = + ChunkedArray::::from_slice(PlSmallStr::EMPTY, &[T::Native::zero()]); let ptr = heap_container.chunks[0].as_mut() as *mut dyn Array as *mut PrimitiveArray; @@ -274,7 +276,7 @@ mod inner_mod { values.into(), Some(validity.into()), ); - Ok(Self::with_chunk(self.name(), arr)) + Ok(Self::with_chunk(self.name().clone(), arr)) } } } diff --git a/crates/polars-core/src/chunked_array/ops/search_sorted.rs b/crates/polars-core/src/chunked_array/ops/search_sorted.rs index e31599429aae..5e97f0818176 100644 --- a/crates/polars-core/src/chunked_array/ops/search_sorted.rs +++ b/crates/polars-core/src/chunked_array/ops/search_sorted.rs @@ -38,8 +38,10 @@ where } /// Search through a series of chunks for the first position where f(x) is true, -/// assuming it is first always false and then always true. It repeats this for -/// each value in search_values. If the search value is null null_idx is returned. +/// assuming it is first always false and then always true. +/// +/// It repeats this for each value in search_values. If the search value is null null_idx is +/// returned. /// /// Assumes the chunks are non-empty. pub fn lower_bound_chunks<'a, T, F>( diff --git a/crates/polars-core/src/chunked_array/ops/set.rs b/crates/polars-core/src/chunked_array/ops/set.rs index 77aba5673ea2..5717cacae98e 100644 --- a/crates/polars-core/src/chunked_array/ops/set.rs +++ b/crates/polars-core/src/chunked_array/ops/set.rs @@ -57,7 +57,7 @@ where value, T::get_dtype().to_arrow(CompatLevel::newest()), )?; - return Ok(Self::with_chunk(self.name(), arr)); + return Ok(Self::with_chunk(self.name().clone(), arr)); } // Other fast path. Slightly slower as it does not do a memcpy. else { @@ -71,7 +71,7 @@ where *val = value; Ok(()) })?; - return Ok(Self::from_vec(self.name(), av)); + return Ok(Self::from_vec(self.name().clone(), av)); } } } @@ -86,7 +86,7 @@ where where F: Fn(Option) -> Option, { - let mut builder = PrimitiveChunkedBuilder::::new(self.name(), self.len()); + let mut builder = PrimitiveChunkedBuilder::::new(self.name().clone(), self.len()); impl_scatter_with!(self, builder, idx, f) } @@ -109,7 +109,7 @@ where T::get_dtype().to_arrow(CompatLevel::newest()), ) }); - Ok(ChunkedArray::from_chunk_iter(self.name(), chunks)) + Ok(ChunkedArray::from_chunk_iter(self.name().clone(), chunks)) } else { // slow path, could be optimized. 
let ca = mask @@ -120,7 +120,7 @@ where _ => opt_val, }) .collect_trusted::() - .with_name(self.name()); + .with_name(self.name().clone()); Ok(ca) } } @@ -160,7 +160,7 @@ impl<'a> ChunkSet<'a, bool, bool> for BooleanChunked { validity.set(i, f(input).unwrap_or(false)); } let arr = BooleanArray::from_data_default(values.into(), Some(validity.into())); - Ok(BooleanChunked::with_chunk(self.name(), arr)) + Ok(BooleanChunked::with_chunk(self.name().clone(), arr)) } fn set(&'a self, mask: &BooleanChunked, value: Option) -> PolarsResult { @@ -173,7 +173,7 @@ impl<'a> ChunkSet<'a, bool, bool> for BooleanChunked { _ => opt_val, }) .collect_trusted::() - .with_name(self.name()); + .with_name(self.name().clone()); Ok(ca) } } @@ -189,7 +189,7 @@ impl<'a> ChunkSet<'a, &'a str, String> for StringChunked { { let idx_iter = idx.into_iter(); let mut ca_iter = self.into_iter().enumerate(); - let mut builder = StringChunkedBuilder::new(self.name(), self.len()); + let mut builder = StringChunkedBuilder::new(self.name().clone(), self.len()); for current_idx in idx_iter.into_iter().map(|i| i as usize) { polars_ensure!(current_idx < self.len(), oob = current_idx, self.len()); @@ -220,7 +220,7 @@ impl<'a> ChunkSet<'a, &'a str, String> for StringChunked { Self: Sized, F: Fn(Option<&'a str>) -> Option, { - let mut builder = StringChunkedBuilder::new(self.name(), self.len()); + let mut builder = StringChunkedBuilder::new(self.name().clone(), self.len()); impl_scatter_with!(self, builder, idx, f) } @@ -237,7 +237,7 @@ impl<'a> ChunkSet<'a, &'a str, String> for StringChunked { _ => opt_val, }) .collect_trusted::() - .with_name(self.name()); + .with_name(self.name().clone()); Ok(ca) } } @@ -252,7 +252,7 @@ impl<'a> ChunkSet<'a, &'a [u8], Vec> for BinaryChunked { Self: Sized, { let mut ca_iter = self.into_iter().enumerate(); - let mut builder = BinaryChunkedBuilder::new(self.name(), self.len()); + let mut builder = BinaryChunkedBuilder::new(self.name().clone(), self.len()); for current_idx in idx.into_iter().map(|i| i as usize) { polars_ensure!(current_idx < self.len(), oob = current_idx, self.len()); @@ -283,7 +283,7 @@ impl<'a> ChunkSet<'a, &'a [u8], Vec> for BinaryChunked { Self: Sized, F: Fn(Option<&'a [u8]>) -> Option>, { - let mut builder = BinaryChunkedBuilder::new(self.name(), self.len()); + let mut builder = BinaryChunkedBuilder::new(self.name().clone(), self.len()); impl_scatter_with!(self, builder, idx, f) } @@ -300,7 +300,7 @@ impl<'a> ChunkSet<'a, &'a [u8], Vec> for BinaryChunked { _ => opt_val, }) .collect_trusted::() - .with_name(self.name()); + .with_name(self.name().clone()); Ok(ca) } } @@ -311,23 +311,26 @@ mod test { #[test] fn test_set() { - let ca = Int32Chunked::new("a", &[1, 2, 3]); - let mask = BooleanChunked::new("mask", &[false, true, false]); + let ca = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3]); + let mask = BooleanChunked::new(PlSmallStr::from_static("mask"), &[false, true, false]); let ca = ca.set(&mask, Some(5)).unwrap(); assert_eq!(Vec::from(&ca), &[Some(1), Some(5), Some(3)]); - let ca = Int32Chunked::new("a", &[1, 2, 3]); - let mask = BooleanChunked::new("mask", &[None, Some(true), None]); + let ca = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3]); + let mask = BooleanChunked::new(PlSmallStr::from_static("mask"), &[None, Some(true), None]); let ca = ca.set(&mask, Some(5)).unwrap(); assert_eq!(Vec::from(&ca), &[Some(1), Some(5), Some(3)]); - let ca = Int32Chunked::new("a", &[1, 2, 3]); - let mask = BooleanChunked::new("mask", &[None, None, None]); + let ca = 
Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3]); + let mask = BooleanChunked::new(PlSmallStr::from_static("mask"), &[None, None, None]); let ca = ca.set(&mask, Some(5)).unwrap(); assert_eq!(Vec::from(&ca), &[Some(1), Some(2), Some(3)]); - let ca = Int32Chunked::new("a", &[1, 2, 3]); - let mask = BooleanChunked::new("mask", &[Some(true), Some(false), None]); + let ca = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3]); + let mask = BooleanChunked::new( + PlSmallStr::from_static("mask"), + &[Some(true), Some(false), None], + ); let ca = ca.set(&mask, Some(5)).unwrap(); assert_eq!(Vec::from(&ca), &[Some(5), Some(2), Some(3)]); @@ -337,30 +340,39 @@ mod test { assert!(ca.scatter_single(vec![0, 10], Some(0)).is_err()); // test booleans - let ca = BooleanChunked::new("a", &[true, true, true]); - let mask = BooleanChunked::new("mask", &[false, true, false]); + let ca = BooleanChunked::new(PlSmallStr::from_static("a"), &[true, true, true]); + let mask = BooleanChunked::new(PlSmallStr::from_static("mask"), &[false, true, false]); let ca = ca.set(&mask, None).unwrap(); assert_eq!(Vec::from(&ca), &[Some(true), None, Some(true)]); // test string - let ca = StringChunked::new("a", &["foo", "foo", "foo"]); - let mask = BooleanChunked::new("mask", &[false, true, false]); + let ca = StringChunked::new(PlSmallStr::from_static("a"), &["foo", "foo", "foo"]); + let mask = BooleanChunked::new(PlSmallStr::from_static("mask"), &[false, true, false]); let ca = ca.set(&mask, Some("bar")).unwrap(); assert_eq!(Vec::from(&ca), &[Some("foo"), Some("bar"), Some("foo")]); } #[test] fn test_set_null_values() { - let ca = Int32Chunked::new("a", &[Some(1), None, Some(3)]); - let mask = BooleanChunked::new("mask", &[Some(false), Some(true), None]); + let ca = Int32Chunked::new(PlSmallStr::from_static("a"), &[Some(1), None, Some(3)]); + let mask = BooleanChunked::new( + PlSmallStr::from_static("mask"), + &[Some(false), Some(true), None], + ); let ca = ca.set(&mask, Some(2)).unwrap(); assert_eq!(Vec::from(&ca), &[Some(1), Some(2), Some(3)]); - let ca = StringChunked::new("a", &[Some("foo"), None, Some("bar")]); + let ca = StringChunked::new( + PlSmallStr::from_static("a"), + &[Some("foo"), None, Some("bar")], + ); let ca = ca.set(&mask, Some("foo")).unwrap(); assert_eq!(Vec::from(&ca), &[Some("foo"), Some("foo"), Some("bar")]); - let ca = BooleanChunked::new("a", &[Some(false), None, Some(true)]); + let ca = BooleanChunked::new( + PlSmallStr::from_static("a"), + &[Some(false), None, Some(true)], + ); let ca = ca.set(&mask, Some(true)).unwrap(); assert_eq!(Vec::from(&ca), &[Some(false), Some(true), Some(true)]); } diff --git a/crates/polars-core/src/chunked_array/ops/shift.rs b/crates/polars-core/src/chunked_array/ops/shift.rs index e01bd3ab6945..09f71a5ed5d6 100644 --- a/crates/polars-core/src/chunked_array/ops/shift.rs +++ b/crates/polars-core/src/chunked_array/ops/shift.rs @@ -1,6 +1,7 @@ use num_traits::{abs, clamp}; use crate::prelude::*; +use crate::series::implementations::null::NullChunked; macro_rules! impl_shift_fill { ($self:ident, $periods:expr, $fill_value:expr) => {{ @@ -8,8 +9,8 @@ macro_rules! 
impl_shift_fill { if fill_length >= $self.len() { return match $fill_value { - Some(fill) => Self::full($self.name(), fill, $self.len()), - None => Self::full_null($self.name(), $self.len()), + Some(fill) => Self::full($self.name().clone(), fill, $self.len()), + None => Self::full_null($self.name().clone(), $self.len()), }; } let slice_offset = (-$periods).max(0) as i64; @@ -17,15 +18,15 @@ macro_rules! impl_shift_fill { let mut slice = $self.slice(slice_offset, length); let mut fill = match $fill_value { - Some(val) => Self::full($self.name(), val, fill_length), - None => Self::full_null($self.name(), fill_length), + Some(val) => Self::full($self.name().clone(), val, fill_length), + None => Self::full_null($self.name().clone(), fill_length), }; if $periods < 0 { - slice.append(&fill); + slice.append(&fill).unwrap(); slice } else { - fill.append(&slice); + fill.append(&slice).unwrap(); fill } }}; @@ -111,8 +112,12 @@ impl ChunkShiftFill> for ListChunked { let fill_length = abs(periods) as usize; let mut fill = match fill_value { - Some(val) => Self::full(self.name(), val, fill_length), - None => ListChunked::full_null_with_dtype(self.name(), fill_length, self.inner_dtype()), + Some(val) => Self::full(self.name().clone(), val, fill_length), + None => ListChunked::full_null_with_dtype( + self.name().clone(), + fill_length, + self.inner_dtype(), + ), }; if periods < 0 { @@ -143,10 +148,13 @@ impl ChunkShiftFill> for ArrayChunked { let fill_length = abs(periods) as usize; let mut fill = match fill_value { - Some(val) => Self::full(self.name(), val, fill_length), - None => { - ArrayChunked::full_null_with_dtype(self.name(), fill_length, self.inner_dtype(), 0) - }, + Some(val) => Self::full(self.name().clone(), val, fill_length), + None => ArrayChunked::full_null_with_dtype( + self.name().clone(), + fill_length, + self.inner_dtype(), + 0, + ), }; if periods < 0 { @@ -183,13 +191,41 @@ impl ChunkShift> for ObjectChunked { } } +#[cfg(feature = "dtype-struct")] +impl ChunkShift for StructChunked { + fn shift(&self, periods: i64) -> ChunkedArray { + // This has its own implementation because a ArrayChunked cannot have a full-null without + // knowing the inner type + let periods = clamp(periods, -(self.len() as i64), self.len() as i64); + let slice_offset = (-periods).max(0); + let length = self.len() - abs(periods) as usize; + let mut slice = self.slice(slice_offset, length); + + let fill_length = abs(periods) as usize; + + // Go via null, so the cast creates the proper struct type. 
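// Casting an all-null NullChunked of length `fill_length` to `self.dtype()` yields a
// StructChunked with the correct field layout where every row is null; that fill is then
// appended after the shifted slice (periods < 0) or placed before it (periods >= 0),
// mirroring the ListChunked/ArrayChunked fill logic above.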
+ let fill = NullChunked::new(self.name().clone(), fill_length) + .cast(self.dtype(), Default::default()) + .unwrap(); + let mut fill = fill.struct_().unwrap().clone(); + + if periods < 0 { + slice.append(&fill).unwrap(); + slice + } else { + fill.append(&slice).unwrap(); + fill + } + } +} + #[cfg(test)] mod test { use crate::prelude::*; #[test] fn test_shift() { - let ca = Int32Chunked::new("", &[1, 2, 3]); + let ca = Int32Chunked::new(PlSmallStr::EMPTY, &[1, 2, 3]); // shift by 0, 1, 2, 3, 4 let shifted = ca.shift_and_fill(0, Some(5)); @@ -222,7 +258,7 @@ mod test { assert_eq!(Vec::from(&shifted), &[Some(3), None, None]); // string - let s = Series::new("a", ["a", "b", "c"]); + let s = Series::new(PlSmallStr::from_static("a"), ["a", "b", "c"]); let shifted = s.shift(-1); assert_eq!( Vec::from(shifted.str().unwrap()), diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs index 300e9f85611e..e774c8ba51f3 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs @@ -1,4 +1,4 @@ -use polars_utils::iter::EnumerateIdxTrait; +use polars_utils::itertools::Itertools; use super::*; diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_sort.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_sort.rs index 724adbebe818..ca34d37318a7 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/arg_sort.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_sort.rs @@ -14,7 +14,7 @@ where } pub(super) fn arg_sort( - name: &str, + name: PlSmallStr, iters: I, options: SortOptions, null_count: usize, @@ -69,7 +69,7 @@ where } pub(super) fn arg_sort_no_nulls( - name: &str, + name: PlSmallStr, iters: I, options: SortOptions, len: usize, diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs index fd9806f27ada..d659ebab7e69 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs @@ -1,7 +1,7 @@ use arrow::compute::utils::combine_validities_and_many; use compare_inner::NullOrderCmp; use polars_row::{convert_columns, EncodingField, RowsEncoded}; -use polars_utils::iter::EnumerateIdxTrait; +use polars_utils::itertools::Itertools; use super::*; use crate::utils::_split_offsets; @@ -100,7 +100,8 @@ pub fn _get_rows_encoded_compat_array(by: &Series) -> PolarsResult { ca.physical().chunks[0].clone() } }, - _ => by.to_arrow(0, CompatLevel::newest()), + // Take physical + _ => by.chunks()[0].clone(), }; Ok(out) } @@ -120,7 +121,10 @@ pub fn encode_rows_vertical_par_unordered(by: &[Series]) -> PolarsResult>>()); - Ok(BinaryOffsetChunked::from_chunk_iter("", chunks?)) + Ok(BinaryOffsetChunked::from_chunk_iter( + PlSmallStr::EMPTY, + chunks?, + )) } // Almost the same but broadcast nulls to the row-encoded array. 
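// The recurring pattern throughout this change: constructors, builders and renames take an
// owned `PlSmallStr` instead of `&str`, so call sites clone an existing name or build one
// up front. A minimal sketch of the new call shapes (names and values are illustrative only):
//
//     let ca = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 3]);
//     let s = Series::new("b".into(), &[1, 2, 3]);        // `&str` still works via `.into()`
//     let renamed = ca.with_name(PlSmallStr::from_static("a2"));
//     let anon = IdxCa::new_vec(PlSmallStr::EMPTY, vec![]); // empty name for helper columns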
@@ -155,12 +159,18 @@ pub fn encode_rows_vertical_par_unordered_broadcast_nulls( }); let chunks = POOL.install(|| chunks.collect::>>()); - Ok(BinaryOffsetChunked::from_chunk_iter("", chunks?)) + Ok(BinaryOffsetChunked::from_chunk_iter( + PlSmallStr::EMPTY, + chunks?, + )) } pub(crate) fn encode_rows_unordered(by: &[Series]) -> PolarsResult { let rows = _get_rows_encoded_unordered(by)?; - Ok(BinaryOffsetChunked::with_chunk("", rows.into_array())) + Ok(BinaryOffsetChunked::with_chunk( + PlSmallStr::EMPTY, + rows.into_array(), + )) } pub fn _get_rows_encoded_unordered(by: &[Series]) -> PolarsResult { @@ -169,7 +179,7 @@ pub fn _get_rows_encoded_unordered(by: &[Series]) -> PolarsResult { for by in by { let arr = _get_rows_encoded_compat_array(by)?; let field = EncodingField::new_unsorted(); - match arr.data_type() { + match arr.dtype() { // Flatten the struct fields. ArrowDataType::Struct(_) => { let arr = arr.as_any().downcast_ref::().unwrap(); @@ -205,7 +215,7 @@ pub fn _get_rows_encoded( nulls_last: *null_last, no_order: false, }; - match arr.data_type() { + match arr.dtype() { // Flatten the struct fields. ArrowDataType::Struct(_) => { let arr = arr.as_any().downcast_ref::().unwrap(); @@ -225,7 +235,7 @@ pub fn _get_rows_encoded( } pub fn _get_rows_encoded_ca( - name: &str, + name: PlSmallStr, by: &[Series], descending: &[bool], nulls_last: &[bool], @@ -243,7 +253,7 @@ pub fn _get_rows_encoded_arr( } pub fn _get_rows_encoded_ca_unordered( - name: &str, + name: PlSmallStr, by: &[Series], ) -> PolarsResult { _get_rows_encoded_unordered(by) diff --git a/crates/polars-core/src/chunked_array/ops/sort/categorical.rs b/crates/polars-core/src/chunked_array/ops/sort/categorical.rs index afc80026313b..0dcb2cb84b51 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/categorical.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/categorical.rs @@ -19,7 +19,7 @@ impl CategoricalChunked { let cats: UInt32Chunked = vals .into_iter() .map(|(idx, _v)| idx) - .collect_ca_trusted(self.name()); + .collect_ca_trusted(self.name().clone()); // SAFETY: // we only reordered the indexes so we are still in bounds @@ -61,7 +61,7 @@ impl CategoricalChunked { if self.uses_lexical_ordering() { let iters = [self.iter_str()]; arg_sort::arg_sort( - self.name(), + self.name().clone(), iters, options, self.physical().null_count(), @@ -124,7 +124,7 @@ mod test { enable_string_cache(); } - let s = Series::new("", init) + let s = Series::new(PlSmallStr::EMPTY, init) .cast(&DataType::Categorical(None, CategoricalOrdering::Lexical))?; let ca = s.categorical()?; let ca_lexical = ca.clone(); @@ -132,7 +132,8 @@ mod test { let out = ca_lexical.sort(false); assert_order(&out, &["a", "b", "c", "d"]); - let s = Series::new("", init).cast(&DataType::Categorical(None, Default::default()))?; + let s = Series::new(PlSmallStr::EMPTY, init) + .cast(&DataType::Categorical(None, Default::default()))?; let ca = s.categorical()?; let out = ca.sort(false); @@ -159,7 +160,7 @@ mod test { enable_string_cache(); } - let s = Series::new("", init) + let s = Series::new(PlSmallStr::EMPTY, init) .cast(&DataType::Categorical(None, CategoricalOrdering::Lexical))?; let ca = s.categorical()?; let ca_lexical: CategoricalChunked = ca.clone(); diff --git a/crates/polars-core/src/chunked_array/ops/sort/mod.rs b/crates/polars-core/src/chunked_array/ops/sort/mod.rs index eb16506d5127..1c1940b6f10d 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/mod.rs @@ -165,7 +165,7 @@ where 
sort_impl_unstable(vals.as_mut_slice(), options); - let mut ca = ChunkedArray::from_vec(ca.name(), vals); + let mut ca = ChunkedArray::from_vec(ca.name().clone(), vals); let s = if options.descending { IsSorted::Descending } else { @@ -205,7 +205,7 @@ where vals.into(), Some(create_validity(len, null_count, options.nulls_last)), ); - let mut new_ca = ChunkedArray::with_chunk(ca.name(), arr); + let mut new_ca = ChunkedArray::with_chunk(ca.name().clone(), arr); let s = if options.descending { IsSorted::Descending } else { @@ -225,12 +225,12 @@ where let iter = ca .downcast_iter() .map(|arr| arr.values().as_slice().iter().copied()); - arg_sort::arg_sort_no_nulls(ca.name(), iter, options, ca.len()) + arg_sort::arg_sort_no_nulls(ca.name().clone(), iter, options, ca.len()) } else { let iter = ca .downcast_iter() .map(|arr| arr.iter().map(|opt| opt.copied())); - arg_sort::arg_sort(ca.name(), iter, options, ca.null_count(), ca.len()) + arg_sort::arg_sort(ca.name().clone(), iter, options, ca.null_count(), ca.len()) } } @@ -409,14 +409,14 @@ impl ChunkSort for BinaryChunked { fn arg_sort(&self, options: SortOptions) -> IdxCa { if self.null_count() == 0 { arg_sort::arg_sort_no_nulls( - self.name(), + self.name().clone(), self.downcast_iter().map(|arr| arr.values_iter()), options, self.len(), ) } else { arg_sort::arg_sort( - self.name(), + self.name().clone(), self.downcast_iter().map(|arr| arr.iter()), options, self.null_count(), @@ -477,7 +477,7 @@ impl ChunkSort for BinaryOffsetChunked { let arr = unsafe { BinaryArray::from_data_unchecked_default(offsets.into(), values.into(), None) }; - ChunkedArray::with_chunk(self.name(), arr) + ChunkedArray::with_chunk(self.name().clone(), arr) }, (_, true) => { for val in v { @@ -495,7 +495,7 @@ impl ChunkSort for BinaryOffsetChunked { Some(create_validity(len, null_count, true)), ) }; - ChunkedArray::with_chunk(self.name(), arr) + ChunkedArray::with_chunk(self.name().clone(), arr) }, (_, false) => { offsets.extend(std::iter::repeat(length_so_far).take(null_count)); @@ -514,7 +514,7 @@ impl ChunkSort for BinaryOffsetChunked { Some(create_validity(len, null_count, false)), ) }; - ChunkedArray::with_chunk(self.name(), arr) + ChunkedArray::with_chunk(self.name().clone(), arr) }, }; @@ -552,13 +552,16 @@ impl ChunkSort for BinaryOffsetChunked { if self.null_count() == 0 { argsort(&mut idx); - IdxCa::from_vec(self.name(), idx) + IdxCa::from_vec(self.name().clone(), idx) } else { // This branch (almost?) never gets called as the row-encoding also encodes nulls. 
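// partition_nulls moves the indices of null entries to one end of `idx` (per
// `options.nulls_last`) and hands back the non-null slice together with the output validity,
// so only that slice still needs to be argsorted.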
let (partitioned_part, validity) = partition_nulls(&mut idx, arr.validity().cloned(), options); argsort(partitioned_part); - IdxCa::with_chunk(self.name(), IdxArr::from_data_default(idx.into(), validity)) + IdxCa::with_chunk( + self.name().clone(), + IdxArr::from_data_default(idx.into(), validity), + ) } } @@ -595,7 +598,7 @@ impl ChunkSort for BinaryOffsetChunked { impl StructChunked { pub(crate) fn arg_sort(&self, options: SortOptions) -> IdxCa { let bin = _get_rows_encoded_ca( - self.name(), + self.name().clone(), &[self.clone().into_series()], &[options.descending], &[options.nulls_last], @@ -656,7 +659,7 @@ impl ChunkSort for BooleanChunked { } let mut ca: BooleanChunked = vals.into_iter().collect_trusted(); - ca.rename(self.name()); + ca.rename(self.name().clone()); ca } @@ -672,14 +675,14 @@ impl ChunkSort for BooleanChunked { fn arg_sort(&self, options: SortOptions) -> IdxCa { if self.null_count() == 0 { arg_sort::arg_sort_no_nulls( - self.name(), + self.name().clone(), self.downcast_iter().map(|arr| arr.values_iter()), options, self.len(), ) } else { arg_sort::arg_sort( - self.name(), + self.name().clone(), self.downcast_iter().map(|arr| arr.iter()), options, self.null_count(), @@ -721,7 +724,7 @@ pub(crate) fn convert_sort_column_multi_sort(s: &Series) -> PolarsResult .iter() .map(convert_sort_column_multi_sort) .collect::>>()?; - let mut out = StructChunked::from_series(ca.name(), &new_fields)?; + let mut out = StructChunked::from_series(ca.name().clone(), &new_fields)?; out.zip_outer_validity(ca); out.into_series() }, @@ -775,7 +778,7 @@ mod test { #[test] fn test_arg_sort() { let a = Int32Chunked::new( - "a", + PlSmallStr::from_static("a"), &[ Some(1), // 0 Some(5), // 1 @@ -809,7 +812,7 @@ mod test { #[test] fn test_sort() { let a = Int32Chunked::new( - "a", + PlSmallStr::from_static("a"), &[ Some(1), Some(5), @@ -859,7 +862,10 @@ mod test { None ] ); - let b = BooleanChunked::new("b", &[Some(false), Some(true), Some(false)]); + let b = BooleanChunked::new( + PlSmallStr::from_static("b"), + &[Some(false), Some(true), Some(false)], + ); let out = b.sort_with(SortOptions::default().with_order_descending(true)); assert_eq!(Vec::from(&out), &[Some(true), Some(false), Some(false)]); let out = b.sort_with(SortOptions::default().with_order_descending(false)); @@ -869,9 +875,12 @@ mod test { #[test] #[cfg_attr(miri, ignore)] fn test_arg_sort_multiple() -> PolarsResult<()> { - let a = Int32Chunked::new("a", &[1, 2, 1, 1, 3, 4, 3, 3]); - let b = Int64Chunked::new("b", &[0, 1, 2, 3, 4, 5, 6, 1]); - let c = StringChunked::new("c", &["a", "b", "c", "d", "e", "f", "g", "h"]); + let a = Int32Chunked::new(PlSmallStr::from_static("a"), &[1, 2, 1, 1, 3, 4, 3, 3]); + let b = Int64Chunked::new(PlSmallStr::from_static("b"), &[0, 1, 2, 3, 4, 5, 6, 1]); + let c = StringChunked::new( + PlSmallStr::from_static("c"), + &["a", "b", "c", "d", "e", "f", "g", "h"], + ); let df = DataFrame::new(vec![a.into_series(), b.into_series(), c.into_series()])?; let out = df.sort(["a", "b", "c"], SortMultipleOptions::default())?; @@ -890,8 +899,12 @@ mod test { ); // now let the first sort be a string - let a = StringChunked::new("a", &["a", "b", "c", "a", "b", "c"]).into_series(); - let b = Int32Chunked::new("b", &[5, 4, 2, 3, 4, 5]).into_series(); + let a = StringChunked::new( + PlSmallStr::from_static("a"), + &["a", "b", "c", "a", "b", "c"], + ) + .into_series(); + let b = Int32Chunked::new(PlSmallStr::from_static("b"), &[5, 4, 2, 3, 4, 5]).into_series(); let df = DataFrame::new(vec![a, b])?; let out = 
df.sort(["a", "b"], SortMultipleOptions::default())?; @@ -931,7 +944,10 @@ mod test { #[test] fn test_sort_string() { - let ca = StringChunked::new("a", &[Some("a"), None, Some("c"), None, Some("b")]); + let ca = StringChunked::new( + PlSmallStr::from_static("a"), + &[Some("a"), None, Some("c"), None, Some("b")], + ); let out = ca.sort_with(SortOptions { descending: false, nulls_last: false, @@ -970,7 +986,10 @@ mod test { assert_eq!(Vec::from(&out), expected); // no nulls - let ca = StringChunked::new("a", &[Some("a"), Some("c"), Some("b")]); + let ca = StringChunked::new( + PlSmallStr::from_static("a"), + &[Some("a"), Some("c"), Some("b")], + ); let out = ca.sort(false); let expected = &[Some("a"), Some("b"), Some("c")]; assert_eq!(Vec::from(&out), expected); diff --git a/crates/polars-core/src/chunked_array/ops/sort/options.rs b/crates/polars-core/src/chunked_array/ops/sort/options.rs index 8726da26774a..046d0b251b04 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/options.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/options.rs @@ -12,7 +12,7 @@ use crate::prelude::*; /// /// ``` /// # use polars_core::prelude::*; -/// let s = Series::new("a", [Some(5), Some(2), Some(3), Some(4), None].as_ref()); +/// let s = Series::new("a".into(), [Some(5), Some(2), Some(3), Some(4), None].as_ref()); /// let sorted = s /// .sort( /// SortOptions::default() @@ -23,7 +23,7 @@ use crate::prelude::*; /// .unwrap(); /// assert_eq!( /// sorted, -/// Series::new("a", [Some(5), Some(4), Some(3), Some(2), None].as_ref()) +/// Series::new("a".into(), [Some(5), Some(4), Some(3), Some(2), None].as_ref()) /// ); /// ``` #[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)] diff --git a/crates/polars-core/src/chunked_array/ops/unique/mod.rs b/crates/polars-core/src/chunked_array/ops/unique/mod.rs index 40e16f08ed95..b645088b4d68 100644 --- a/crates/polars-core/src/chunked_array/ops/unique/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/unique/mod.rs @@ -116,7 +116,7 @@ where } let arr: PrimitiveArray = arr.into(); - Ok(ChunkedArray::with_chunk(self.name(), arr)) + Ok(ChunkedArray::with_chunk(self.name().clone(), arr)) } else { let mask = self.not_equal_missing(&self.shift(1)); self.filter(&mask) @@ -126,16 +126,12 @@ where if !T::Native::is_float() && MetadataEnv::experimental_enabled() { let md = self.metadata(); if let (Some(min), Some(max)) = (md.get_min_value(), md.get_max_value()) { - let data_type = self - .field - .as_ref() - .data_type() - .to_arrow(CompatLevel::oldest()); + let dtype = self.field.as_ref().dtype().to_arrow(CompatLevel::oldest()); if let Some(mut state) = PrimitiveRangedUniqueState::new( *min, *max, self.null_count() > 0, - data_type, + dtype, ) { use polars_compute::unique::RangedUniqueKernel; @@ -149,7 +145,7 @@ where let unique = state.finalize_unique(); - return Ok(Self::with_chunk(self.name(), unique)); + return Ok(Self::with_chunk(self.name().clone(), unique)); } } } @@ -161,7 +157,7 @@ where } fn arg_unique(&self) -> PolarsResult { - Ok(IdxCa::from_vec(self.name(), arg_unique_ca!(self))) + Ok(IdxCa::from_vec(self.name().clone(), arg_unique_ca!(self))) } fn n_unique(&self) -> PolarsResult { @@ -230,7 +226,7 @@ impl ChunkUnique for BinaryChunked { set.extend(arr.values_iter()) } Ok(BinaryChunked::from_iter_values( - self.name(), + self.name().clone(), set.iter().copied(), )) }, @@ -241,7 +237,7 @@ impl ChunkUnique for BinaryChunked { set.extend(arr.iter()) } Ok(BinaryChunked::from_iter_options( - self.name(), + self.name().clone(), set.iter().copied(), )) }, @@ 
-249,7 +245,7 @@ impl ChunkUnique for BinaryChunked { } fn arg_unique(&self) -> PolarsResult { - Ok(IdxCa::from_vec(self.name(), arg_unique_ca!(self))) + Ok(IdxCa::from_vec(self.name().clone(), arg_unique_ca!(self))) } fn n_unique(&self) -> PolarsResult { @@ -272,13 +268,9 @@ impl ChunkUnique for BooleanChunked { fn unique(&self) -> PolarsResult { use polars_compute::unique::RangedUniqueKernel; - let data_type = self - .field - .as_ref() - .data_type() - .to_arrow(CompatLevel::oldest()); + let dtype = self.field.as_ref().dtype().to_arrow(CompatLevel::oldest()); let has_null = self.null_count() > 0; - let mut state = BooleanUniqueKernelState::new(has_null, data_type); + let mut state = BooleanUniqueKernelState::new(has_null, dtype); for arr in self.downcast_iter() { state.append(arr); @@ -290,11 +282,11 @@ impl ChunkUnique for BooleanChunked { let unique = state.finalize_unique(); - Ok(Self::with_chunk(self.name(), unique)) + Ok(Self::with_chunk(self.name().clone(), unique)) } fn arg_unique(&self) -> PolarsResult { - Ok(IdxCa::from_vec(self.name(), arg_unique_ca!(self))) + Ok(IdxCa::from_vec(self.name().clone(), arg_unique_ca!(self))) } } @@ -304,7 +296,8 @@ mod test { #[test] fn unique() { - let ca = ChunkedArray::::from_slice("a", &[1, 2, 3, 2, 1]); + let ca = + ChunkedArray::::from_slice(PlSmallStr::from_static("a"), &[1, 2, 3, 2, 1]); assert_eq!( ca.unique() .unwrap() @@ -313,13 +306,16 @@ mod test { .collect::>(), vec![Some(1), Some(2), Some(3)] ); - let ca = BooleanChunked::from_slice("a", &[true, false, true]); + let ca = BooleanChunked::from_slice(PlSmallStr::from_static("a"), &[true, false, true]); assert_eq!( ca.unique().unwrap().into_iter().collect::>(), vec![Some(false), Some(true)] ); - let ca = StringChunked::new("", &[Some("a"), None, Some("a"), Some("b"), None]); + let ca = StringChunked::new( + PlSmallStr::EMPTY, + &[Some("a"), None, Some("a"), Some("b"), None], + ); assert_eq!( Vec::from(&ca.unique().unwrap().sort(false)), &[None, Some("a"), Some("b")] @@ -328,7 +324,8 @@ mod test { #[test] fn arg_unique() { - let ca = ChunkedArray::::from_slice("a", &[1, 2, 1, 1, 3]); + let ca = + ChunkedArray::::from_slice(PlSmallStr::from_static("a"), &[1, 2, 1, 1, 3]); assert_eq!( ca.arg_unique().unwrap().into_iter().collect::>(), vec![Some(0), Some(1), Some(4)] diff --git a/crates/polars-core/src/chunked_array/ops/zip.rs b/crates/polars-core/src/chunked_array/ops/zip.rs index 8319c81d9c3c..eb24468d892d 100644 --- a/crates/polars-core/src/chunked_array/ops/zip.rs +++ b/crates/polars-core/src/chunked_array/ops/zip.rs @@ -1,6 +1,6 @@ use arrow::bitmap::Bitmap; use arrow::compute::utils::{combine_validities_and, combine_validities_and_not}; -use polars_compute::if_then_else::IfThenElseKernel; +use polars_compute::if_then_else::{if_then_else_validity, IfThenElseKernel}; #[cfg(feature = "object")] use crate::chunked_array::object::ObjectArray; @@ -26,7 +26,7 @@ where (1, other_len) => src.new_from_index(0, other_len), _ => polars_bail!(ShapeMismatch: SHAPE_MISMATCH_STR), }; - Ok(ret.with_name(if_true.name())) + Ok(ret.with_name(if_true.name().clone())) } fn bool_null_to_false(mask: &BooleanArray) -> Bitmap { @@ -62,7 +62,7 @@ fn combine_validities_chunked< impl ChunkZip for ChunkedArray where - T: PolarsDataType, + T: PolarsDataType, T::Array: for<'a> IfThenElseKernel = T::Physical<'a>>, ChunkedArray: ChunkExpandAtIndex, { @@ -94,7 +94,7 @@ where combine_validities_and, ), (Some(t), Some(f)) => { - let dtype = if_true.downcast_iter().next().unwrap().data_type(); + let dtype = 
if_true.downcast_iter().next().unwrap().dtype(); let chunks = mask.downcast_iter().map(|m| { let bm = bool_null_to_false(m); let t = t.clone(); @@ -156,7 +156,7 @@ where polars_bail!(ShapeMismatch: SHAPE_MISMATCH_STR) }; - Ok(ret.with_name(if_true.name())) + Ok(ret.with_name(if_true.name().clone())) } } @@ -206,3 +206,104 @@ impl IfThenElseKernel for ObjectArray { .collect_arr() } } + +#[cfg(feature = "dtype-struct")] +impl ChunkZip for StructChunked { + fn zip_with( + &self, + mask: &BooleanChunked, + other: &ChunkedArray, + ) -> PolarsResult> { + let (l, r, mask) = align_chunks_ternary(self, other, mask); + + // Prepare the boolean arrays such that Null maps to false. + // This prevents every field doing that. + // # SAFETY + // We don't modify the length and update the null count. + let mut mask = mask.into_owned(); + unsafe { + for arr in mask.downcast_iter_mut() { + let bm = bool_null_to_false(arr); + *arr = BooleanArray::from_data_default(bm, None); + } + mask.set_null_count(0); + } + + // Zip all the fields. + let fields = l + .fields_as_series() + .iter() + .zip(r.fields_as_series()) + .map(|(lhs, rhs)| lhs.zip_with_same_type(&mask, &rhs)) + .collect::>>()?; + + let mut out = StructChunked::from_series(self.name().clone(), &fields)?; + + // Zip the validities. + if (l.null_count + r.null_count) > 0 { + let validities = l + .chunks() + .iter() + .zip(r.chunks()) + .map(|(l, r)| (l.validity(), r.validity())); + + fn broadcast(v: Option<&Bitmap>, arr: &ArrayRef) -> Bitmap { + if v.unwrap().get(0).unwrap() { + Bitmap::new_with_value(true, arr.len()) + } else { + Bitmap::new_zeroed(arr.len()) + } + } + + // # SAFETY + // We don't modify the length and update the null count. + unsafe { + for ((arr, (lv, rv)), mask) in out + .chunks_mut() + .iter_mut() + .zip(validities) + .zip(mask.downcast_iter()) + { + // TODO! we can optimize this and use a kernel that is able to broadcast wo/ allocating. 
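// The match below only widens unit-length validity bitmaps: a length-1 side is expanded to
// `arr.len()` (all-true or all-false depending on its single bit) so that
// `if_then_else_validity` can combine both sides with the mask element-wise; sides of equal
// length are simply cloned, and genuinely mismatched lengths are an error.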
+ let (lv, rv) = match (lv.map(|b| b.len()), rv.map(|b| b.len())) { + (Some(1), Some(1)) if arr.len() != 1 => { + let lv = broadcast(lv, arr); + let rv = broadcast(rv, arr); + (Some(lv), Some(rv)) + }, + (Some(a), Some(b)) if a == b => (lv.cloned(), rv.cloned()), + (Some(1), _) => { + let lv = broadcast(lv, arr); + (Some(lv), rv.cloned()) + }, + (_, Some(1)) => { + let rv = broadcast(rv, arr); + (lv.cloned(), Some(rv)) + }, + (None, Some(_)) | (Some(_), None) | (None, None) => { + (lv.cloned(), rv.cloned()) + }, + (Some(a), Some(b)) => { + polars_bail!(InvalidOperation: "got different sizes in 'zip' operation, got length: {a} and {b}") + }, + }; + + // broadcast mask + let validity = if mask.len() != arr.len() && mask.len() == 1 { + if mask.get(0).unwrap() { + lv + } else { + rv + } + } else { + if_then_else_validity(mask.values(), lv.as_ref(), rv.as_ref()) + }; + + *arr = arr.with_validity(validity); + } + } + out.compute_len(); + } + Ok(out) + } +} diff --git a/crates/polars-core/src/chunked_array/random.rs b/crates/polars-core/src/chunked_array/random.rs index 18b1117669fc..94ab33f02cee 100644 --- a/crates/polars-core/src/chunked_array/random.rs +++ b/crates/polars-core/src/chunked_array/random.rs @@ -12,7 +12,7 @@ use crate::utils::NoNull; fn create_rand_index_with_replacement(n: usize, len: usize, seed: Option) -> IdxCa { if len == 0 { - return IdxCa::new_vec("", vec![]); + return IdxCa::new_vec(PlSmallStr::EMPTY, vec![]); } let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64)); let dist = Uniform::new(0, len as IdxSize); @@ -45,7 +45,7 @@ fn create_rand_index_no_replacement( IndexVec::USize(v) => v.into_iter().map(|x| x as IdxSize).collect(), }; } - IdxCa::new_vec("", buf) + IdxCa::new_vec(PlSmallStr::EMPTY, buf) } impl ChunkedArray @@ -251,7 +251,12 @@ where T::Native: Float, { /// Create [`ChunkedArray`] with samples from a Normal distribution. - pub fn rand_normal(name: &str, length: usize, mean: f64, std_dev: f64) -> PolarsResult { + pub fn rand_normal( + name: PlSmallStr, + length: usize, + mean: f64, + std_dev: f64, + ) -> PolarsResult { let normal = Normal::new(mean, std_dev).map_err(to_compute_err)?; let mut builder = PrimitiveChunkedBuilder::::new(name, length); let mut rng = rand::thread_rng(); @@ -264,7 +269,7 @@ where } /// Create [`ChunkedArray`] with samples from a Standard Normal distribution. - pub fn rand_standard_normal(name: &str, length: usize) -> Self { + pub fn rand_standard_normal(name: PlSmallStr, length: usize) -> Self { let mut builder = PrimitiveChunkedBuilder::::new(name, length); let mut rng = rand::thread_rng(); for _ in 0..length { @@ -276,7 +281,7 @@ where } /// Create [`ChunkedArray`] with samples from a Uniform distribution. - pub fn rand_uniform(name: &str, length: usize, low: f64, high: f64) -> Self { + pub fn rand_uniform(name: PlSmallStr, length: usize, low: f64, high: f64) -> Self { let uniform = Uniform::new(low, high); let mut builder = PrimitiveChunkedBuilder::::new(name, length); let mut rng = rand::thread_rng(); @@ -291,7 +296,7 @@ where impl BooleanChunked { /// Create [`ChunkedArray`] with samples from a Bernoulli distribution. 
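// Sketch of the updated owned-name signatures for the random constructors (lengths and
// distribution parameters below are illustrative only):
//
//     let normal = Float64Chunked::rand_standard_normal(PlSmallStr::from_static("x"), 100);
//     let uniform = Float64Chunked::rand_uniform(PlSmallStr::from_static("u"), 100, 0.0, 1.0);
//     let coin = BooleanChunked::rand_bernoulli(PlSmallStr::from_static("coin"), 100, 0.5)?;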
- pub fn rand_bernoulli(name: &str, length: usize, p: f64) -> PolarsResult { + pub fn rand_bernoulli(name: PlSmallStr, length: usize, p: f64) -> PolarsResult { let dist = Bernoulli::new(p).map_err(to_compute_err)?; let mut rng = rand::thread_rng(); let mut builder = BooleanChunkedBuilder::new(name, length); @@ -316,31 +321,71 @@ mod test { // Default samples are random and don't require seeds. assert!(df - .sample_n(&Series::new("s", &[3]), false, false, None) + .sample_n( + &Series::new(PlSmallStr::from_static("s"), &[3]), + false, + false, + None + ) .is_ok()); assert!(df - .sample_frac(&Series::new("frac", &[0.4]), false, false, None) + .sample_frac( + &Series::new(PlSmallStr::from_static("frac"), &[0.4]), + false, + false, + None + ) .is_ok()); // With seeding. assert!(df - .sample_n(&Series::new("s", &[3]), false, false, Some(0)) + .sample_n( + &Series::new(PlSmallStr::from_static("s"), &[3]), + false, + false, + Some(0) + ) .is_ok()); assert!(df - .sample_frac(&Series::new("frac", &[0.4]), false, false, Some(0)) + .sample_frac( + &Series::new(PlSmallStr::from_static("frac"), &[0.4]), + false, + false, + Some(0) + ) .is_ok()); // Without replacement can not sample more than 100%. assert!(df - .sample_frac(&Series::new("frac", &[2.0]), false, false, Some(0)) + .sample_frac( + &Series::new(PlSmallStr::from_static("frac"), &[2.0]), + false, + false, + Some(0) + ) .is_err()); assert!(df - .sample_n(&Series::new("s", &[3]), true, false, Some(0)) + .sample_n( + &Series::new(PlSmallStr::from_static("s"), &[3]), + true, + false, + Some(0) + ) .is_ok()); assert!(df - .sample_frac(&Series::new("frac", &[0.4]), true, false, Some(0)) + .sample_frac( + &Series::new(PlSmallStr::from_static("frac"), &[0.4]), + true, + false, + Some(0) + ) .is_ok()); // With replacement can sample more than 100%. 
assert!(df - .sample_frac(&Series::new("frac", &[2.0]), true, false, Some(0)) + .sample_frac( + &Series::new(PlSmallStr::from_static("frac"), &[2.0]), + true, + false, + Some(0) + ) .is_ok()); } } diff --git a/crates/polars-core/src/chunked_array/struct_/frame.rs b/crates/polars-core/src/chunked_array/struct_/frame.rs index 28149aba4ca8..280a9df6da56 100644 --- a/crates/polars-core/src/chunked_array/struct_/frame.rs +++ b/crates/polars-core/src/chunked_array/struct_/frame.rs @@ -1,8 +1,10 @@ +use polars_utils::pl_str::PlSmallStr; + use crate::frame::DataFrame; use crate::prelude::StructChunked; impl DataFrame { - pub fn into_struct(self, name: &str) -> StructChunked { + pub fn into_struct(self, name: PlSmallStr) -> StructChunked { StructChunked::from_series(name, &self.columns).expect("same invariants") } } diff --git a/crates/polars-core/src/chunked_array/struct_/mod.rs b/crates/polars-core/src/chunked_array/struct_/mod.rs index 9db9a6a7ce35..882251a43d6d 100644 --- a/crates/polars-core/src/chunked_array/struct_/mod.rs +++ b/crates/polars-core/src/chunked_array/struct_/mod.rs @@ -5,9 +5,9 @@ use std::fmt::Write; use arrow::array::StructArray; use arrow::bitmap::Bitmap; use arrow::compute::utils::combine_validities_and; -use arrow::legacy::utils::CustomIterTools; use polars_error::{polars_ensure, PolarsResult}; use polars_utils::aliases::PlHashMap; +use polars_utils::itertools::Itertools; use crate::chunked_array::cast::CastOptions; use crate::chunked_array::ChunkedArray; @@ -18,7 +18,7 @@ use crate::utils::Container; pub type StructChunked = ChunkedArray; -fn constructor(name: &str, fields: &[Series]) -> PolarsResult { +fn constructor(name: PlSmallStr, fields: &[Series]) -> PolarsResult { // Different chunk lengths: rechunk and recurse. if !fields.iter().map(|s| s.n_chunks()).all_equal() { let fields = fields.iter().map(|s| s.rechunk()).collect::>(); @@ -62,7 +62,7 @@ fn constructor(name: &str, fields: &[Series]) -> PolarsResult { } impl StructChunked { - pub fn from_series(name: &str, fields: &[Series]) -> PolarsResult { + pub fn from_series(name: PlSmallStr, fields: &[Series]) -> PolarsResult { let mut names = PlHashSet::with_capacity(fields.len()); let first_len = fields.first().map(|s| s.len()).unwrap_or(0); let mut max_len = first_len; @@ -110,7 +110,7 @@ impl StructChunked { } constructor(name, &new_fields) } else if fields.is_empty() { - let fields = &[Series::new_null("", 0)]; + let fields = &[Series::new_null(PlSmallStr::EMPTY, 0)]; constructor(name, fields) } else { constructor(name, fields) @@ -136,7 +136,11 @@ impl StructChunked { // SAFETY: correct type. 
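// The unchecked construction below stays sound because the field chunks are taken from this
// StructChunked's own arrays, so their Arrow type matches `field.dtype` by construction;
// only the name argument changes from `&field.name` to an owned clone.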
unsafe { - Series::from_chunks_and_dtype_unchecked(&field.name, field_chunks, &field.dtype) + Series::from_chunks_and_dtype_unchecked( + field.name.clone(), + field_chunks, + &field.dtype, + ) } }) .collect() @@ -155,7 +159,7 @@ impl StructChunked { let struct_len = self.len(); let new_fields = dtype_fields .iter() - .map(|new_field| match map.get(new_field.name().as_str()) { + .map(|new_field| match map.get(new_field.name()) { Some(s) => { if unchecked { s.cast_unchecked(&new_field.dtype) @@ -164,14 +168,14 @@ impl StructChunked { } }, None => Ok(Series::full_null( - new_field.name(), + new_field.name().clone(), struct_len, &new_field.dtype, )), }) .collect::>>()?; - let mut out = Self::from_series(self.name(), &new_fields)?; + let mut out = Self::from_series(self.name().clone(), &new_fields)?; if self.null_count > 0 { out.zip_outer_validity(self); } @@ -213,7 +217,7 @@ impl StructChunked { scratch.clear(); } let array = builder.freeze().boxed(); - Series::try_from((ca.name(), array)) + Series::try_from((ca.name().clone(), array)) }, _ => { let fields = self @@ -227,7 +231,7 @@ impl StructChunked { } }) .collect::>>()?; - let mut out = Self::from_series(self.name(), &fields)?; + let mut out = Self::from_series(self.name().clone(), &fields)?; if self.null_count > 0 { out.zip_outer_validity(self); } @@ -272,7 +276,7 @@ impl StructChunked { .iter() .map(func) .collect::>>()?; - Self::from_series(self.name(), &fields).map(|mut ca| { + Self::from_series(self.name().clone(), &fields).map(|mut ca| { if self.null_count > 0 { // SAFETY: we don't change types/ lengths. unsafe { @@ -293,7 +297,7 @@ impl StructChunked { pub fn get_row_encoded(&self, options: SortOptions) -> PolarsResult { let s = self.clone().into_series(); _get_rows_encoded_ca( - self.name(), + self.name().clone(), &[s], &[options.descending], &[options.nulls_last], @@ -350,7 +354,7 @@ impl StructChunked { pub fn field_by_name(&self, name: &str) -> PolarsResult { self.fields_as_series() .into_iter() - .find(|s| s.name() == name) + .find(|s| s.name().as_str() == name) .ok_or_else(|| polars_err!(StructFieldNotFound: "{}", name)) } pub(crate) fn set_outer_validity(&mut self, validity: Option) { diff --git a/crates/polars-core/src/chunked_array/temporal/date.rs b/crates/polars-core/src/chunked_array/temporal/date.rs index 26e52d7d2f0f..ea0bb11d10fc 100644 --- a/crates/polars-core/src/chunked_array/temporal/date.rs +++ b/crates/polars-core/src/chunked_array/temporal/date.rs @@ -25,7 +25,7 @@ impl DateChunked { } /// Construct a new [`DateChunked`] from an iterator over [`NaiveDate`]. - pub fn from_naive_date>(name: &str, v: I) -> Self { + pub fn from_naive_date>(name: PlSmallStr, v: I) -> Self { let unit = v.into_iter().map(naive_date_to_date).collect::>(); Int32Chunked::from_vec(name, unit).into() } @@ -51,7 +51,7 @@ impl DateChunked { /// Construct a new [`DateChunked`] from an iterator over optional [`NaiveDate`]. pub fn from_naive_date_options>>( - name: &str, + name: PlSmallStr, v: I, ) -> Self { let unit = v.into_iter().map(|opt| opt.map(naive_date_to_date)); diff --git a/crates/polars-core/src/chunked_array/temporal/datetime.rs b/crates/polars-core/src/chunked_array/temporal/datetime.rs index 838bc2cd9527..92439e5b7527 100644 --- a/crates/polars-core/src/chunked_array/temporal/datetime.rs +++ b/crates/polars-core/src/chunked_array/temporal/datetime.rs @@ -72,7 +72,7 @@ impl DatetimeChunked { )? 
}, }; - ca.rename(self.name()); + ca.rename(self.name().clone()); Ok(ca) } @@ -86,7 +86,7 @@ impl DatetimeChunked { /// Construct a new [`DatetimeChunked`] from an iterator over [`NaiveDateTime`]. pub fn from_naive_datetime>( - name: &str, + name: PlSmallStr, v: I, tu: TimeUnit, ) -> Self { @@ -100,7 +100,7 @@ impl DatetimeChunked { } pub fn from_naive_datetime_options>>( - name: &str, + name: PlSmallStr, v: I, tu: TimeUnit, ) -> Self { @@ -205,7 +205,7 @@ mod test { // NOTE: the values are checked and correct. let dt = DatetimeChunked::from_naive_datetime( - "name", + PlSmallStr::from_static("name"), datetimes.iter().copied(), TimeUnit::Nanoseconds, ); diff --git a/crates/polars-core/src/chunked_array/temporal/duration.rs b/crates/polars-core/src/chunked_array/temporal/duration.rs index 7c649e3178b0..df8a51388baf 100644 --- a/crates/polars-core/src/chunked_array/temporal/duration.rs +++ b/crates/polars-core/src/chunked_array/temporal/duration.rs @@ -62,7 +62,7 @@ impl DurationChunked { /// Construct a new [`DurationChunked`] from an iterator over [`ChronoDuration`]. pub fn from_duration>( - name: &str, + name: PlSmallStr, v: I, tu: TimeUnit, ) -> Self { @@ -77,7 +77,7 @@ impl DurationChunked { /// Construct a new [`DurationChunked`] from an iterator over optional [`ChronoDuration`]. pub fn from_duration_options>>( - name: &str, + name: PlSmallStr, v: I, tu: TimeUnit, ) -> Self { diff --git a/crates/polars-core/src/chunked_array/temporal/mod.rs b/crates/polars-core/src/chunked_array/temporal/mod.rs index d9f50fe9ad96..e3ab1c01c164 100644 --- a/crates/polars-core/src/chunked_array/temporal/mod.rs +++ b/crates/polars-core/src/chunked_array/temporal/mod.rs @@ -17,6 +17,8 @@ use chrono::NaiveTime; use chrono_tz::Tz; #[cfg(feature = "timezones")] use once_cell::sync::Lazy; +#[cfg(feature = "timezones")] +use polars_utils::pl_str::PlSmallStr; #[cfg(all(feature = "regex", feature = "timezones"))] use regex::Regex; #[cfg(feature = "dtype-time")] @@ -68,14 +70,16 @@ pub fn parse_time_zone(tz: &str) -> PolarsResult { /// > In the "Etc" area, zones west of GMT have a positive sign and those east /// > have a negative sign in their name (e.g "Etc/GMT-14" is 14 hours ahead of GMT). #[cfg(feature = "timezones")] -pub fn parse_fixed_offset(tz: &str) -> PolarsResult { +pub fn parse_fixed_offset(tz: &str) -> PolarsResult { + use polars_utils::format_pl_smallstr; + if let Some(caps) = FIXED_OFFSET_RE.captures(tz) { let sign = match caps.name("sign").map(|s| s.as_str()) { Some("-") => "+", _ => "-", }; let hour = caps.name("hour").unwrap().as_str().parse::().unwrap(); - let etc_tz = format!("Etc/GMT{}{}", sign, hour); + let etc_tz = format_pl_smallstr!("Etc/GMT{}{}", sign, hour); if etc_tz.parse::().is_ok() { return Ok(etc_tz); } diff --git a/crates/polars-core/src/chunked_array/temporal/time.rs b/crates/polars-core/src/chunked_array/temporal/time.rs index 3627189052a5..77e204c765de 100644 --- a/crates/polars-core/src/chunked_array/temporal/time.rs +++ b/crates/polars-core/src/chunked_array/temporal/time.rs @@ -40,7 +40,7 @@ impl TimeChunked { mutarr.freeze().boxed() }); - ca.rename(self.name()); + ca.rename(self.name().clone()); ca } @@ -65,7 +65,7 @@ impl TimeChunked { } /// Construct a new [`TimeChunked`] from an iterator over [`NaiveTime`]. 
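// Sketch of the updated constructor: the name is now an owned PlSmallStr (the chrono values
// below are illustrative only):
//
//     use chrono::NaiveTime;
//     let t = TimeChunked::from_naive_time(
//         PlSmallStr::from_static("t"),
//         [
//             NaiveTime::from_hms_opt(9, 0, 0).unwrap(),
//             NaiveTime::from_hms_opt(17, 30, 0).unwrap(),
//         ],
//     );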
- pub fn from_naive_time>(name: &str, v: I) -> Self { + pub fn from_naive_time>(name: PlSmallStr, v: I) -> Self { let vals = v .into_iter() .map(|nt| time_to_time64ns(&nt)) @@ -75,7 +75,7 @@ impl TimeChunked { /// Construct a new [`TimeChunked`] from an iterator over optional [`NaiveTime`]. pub fn from_naive_time_options>>( - name: &str, + name: PlSmallStr, v: I, ) -> Self { let vals = v.into_iter().map(|opt| opt.map(|nt| time_to_time64ns(&nt))); diff --git a/crates/polars-core/src/chunked_array/trusted_len.rs b/crates/polars-core/src/chunked_array/trusted_len.rs index 84ff13cb906d..2304d74933b0 100644 --- a/crates/polars-core/src/chunked_array/trusted_len.rs +++ b/crates/polars-core/src/chunked_array/trusted_len.rs @@ -168,7 +168,7 @@ where { fn from_iter_trusted_length>(iter: I) -> Self { let arr = BinaryArray::from_iter_values(iter.into_iter()); - ChunkedArray::with_chunk("", arr) + ChunkedArray::with_chunk(PlSmallStr::EMPTY, arr) } } @@ -179,7 +179,7 @@ where fn from_iter_trusted_length>>(iter: I) -> Self { let iter = iter.into_iter(); let arr = BinaryArray::from_iter(iter); - ChunkedArray::with_chunk("", arr) + ChunkedArray::with_chunk(PlSmallStr::EMPTY, arr) } } diff --git a/crates/polars-core/src/datatypes/_serde.rs b/crates/polars-core/src/datatypes/_serde.rs index ee5839663ddf..e9d961ef4be0 100644 --- a/crates/polars-core/src/datatypes/_serde.rs +++ b/crates/polars-core/src/datatypes/_serde.rs @@ -4,6 +4,7 @@ //! We could use [serde_1712](https://github.com/serde-rs/serde/issues/1712), but that gave problems caused by //! [rust_96956](https://github.com/rust-lang/rust/issues/96956), so we make a dummy type without static +#[cfg(feature = "dtype-categorical")] use serde::de::SeqAccess; use serde::{Deserialize, Serialize}; @@ -191,7 +192,7 @@ impl From for DataType { #[cfg(feature = "dtype-categorical")] Categorical(_, ordering) => Self::Categorical(None, ordering), #[cfg(feature = "dtype-categorical")] - Enum(Some(categories), _) => create_enum_data_type(categories.0), + Enum(Some(categories), _) => create_enum_dtype(categories.0), #[cfg(feature = "dtype-categorical")] Enum(None, ordering) => Self::Enum(None, ordering), #[cfg(feature = "dtype-decimal")] diff --git a/crates/polars-core/src/datatypes/aliases.rs b/crates/polars-core/src/datatypes/aliases.rs index 263598de0140..4787b7fcd229 100644 --- a/crates/polars-core/src/datatypes/aliases.rs +++ b/crates/polars-core/src/datatypes/aliases.rs @@ -1,5 +1,7 @@ pub use arrow::legacy::index::IdxArr; -pub use polars_utils::aliases::{InitHashMaps, PlHashMap, PlHashSet, PlIndexMap, PlIndexSet}; +pub use polars_utils::aliases::{ + InitHashMaps, PlHashMap, PlHashSet, PlIndexMap, PlIndexSet, PlRandomState, +}; use super::*; use crate::hashing::IdBuildHasher; @@ -19,7 +21,7 @@ pub type IdxType = UInt32Type; #[cfg(feature = "bigidx")] pub type IdxType = UInt64Type; -pub use smartstring::alias::String as SmartString; +pub use polars_utils::pl_str::PlSmallStr; /// This hashmap uses an IdHasher pub type PlIdHashMap = hashbrown::HashMap; diff --git a/crates/polars-core/src/datatypes/any_value.rs b/crates/polars-core/src/datatypes/any_value.rs index e3db5b5cdba7..c3f5f57e0c68 100644 --- a/crates/polars-core/src/datatypes/any_value.rs +++ b/crates/polars-core/src/datatypes/any_value.rs @@ -1,7 +1,8 @@ #[cfg(feature = "dtype-struct")] use arrow::legacy::trusted_len::TrustedLenPush; use arrow::types::PrimitiveType; -use polars_utils::format_smartstring; +use polars_utils::format_pl_smallstr; +use polars_utils::itertools::Itertools; #[cfg(feature = 
"dtype-struct")] use polars_utils::slice::GetSaferUnchecked; #[cfg(feature = "dtype-categorical")] @@ -66,9 +67,9 @@ pub enum AnyValue<'a> { /// A 64-bit time representing the elapsed time since midnight in nanoseconds #[cfg(feature = "dtype-time")] Time(i64), - #[cfg(feature = "dtype-categorical")] // If syncptr is_null the data is in the rev-map // otherwise it is in the array pointer + #[cfg(feature = "dtype-categorical")] Categorical(u32, &'a RevMapping, SyncPtr), #[cfg(feature = "dtype-categorical")] Enum(u32, &'a RevMapping, SyncPtr), @@ -76,22 +77,21 @@ pub enum AnyValue<'a> { List(Series), #[cfg(feature = "dtype-array")] Array(Series, usize), - #[cfg(feature = "object")] /// Can be used to fmt and implements Any, so can be downcasted to the proper value type. #[cfg(feature = "object")] Object(&'a dyn PolarsObjectSafe), #[cfg(feature = "object")] ObjectOwned(OwnedObject), - #[cfg(feature = "dtype-struct")] // 3 pointers and thus not larger than string/vec // - idx in the `&StructArray` // - The array itself // - The fields + #[cfg(feature = "dtype-struct")] Struct(usize, &'a StructArray, &'a [Field]), #[cfg(feature = "dtype-struct")] StructOwned(Box<(Vec>, Vec)>), /// An UTF8 encoded string type. - StringOwned(smartstring::alias::String), + StringOwned(PlSmallStr), Binary(&'a [u8]), BinaryOwned(Vec), /// A 128-bit fixed point decimal number with a scale. @@ -325,8 +325,8 @@ impl<'a> Deserialize<'a> for AnyValue<'static> { AnyValue::List(value) }, (AvField::StringOwned, variant) => { - let value: String = variant.newtype_variant()?; - AnyValue::StringOwned(value.into()) + let value: PlSmallStr = variant.newtype_variant()?; + AnyValue::StringOwned(value) }, (AvField::BinaryOwned, variant) => { let value = variant.newtype_variant()?; @@ -341,17 +341,23 @@ impl<'a> Deserialize<'a> for AnyValue<'static> { } impl AnyValue<'static> { - pub fn zero(dtype: &DataType) -> Self { + pub fn zero_sum(dtype: &DataType) -> Self { match dtype { - DataType::String => AnyValue::StringOwned("".into()), - DataType::Boolean => AnyValue::Boolean(false), - // SAFETY: - // Numeric values are static, inform the compiler of this. + DataType::String => AnyValue::StringOwned(PlSmallStr::EMPTY), + DataType::Binary => AnyValue::BinaryOwned(Vec::new()), + DataType::Boolean => (0 as IdxSize).into(), + // SAFETY: numeric values are static, inform the compiler of this. d if d.is_numeric() => unsafe { std::mem::transmute::, AnyValue<'static>>( AnyValue::UInt8(0).cast(dtype), ) }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(unit) => AnyValue::Duration(0, *unit), + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_p, s) => { + AnyValue::Decimal(0, s.expect("unknown scale during execution")) + }, _ => AnyValue::Null, } } @@ -448,7 +454,7 @@ impl<'a> AnyValue<'a> { NumCast::from((*v).parse::().ok()?) 
} }, - StringOwned(v) => String(v).extract(), + StringOwned(v) => String(v.as_str()).extract(), _ => None, } } @@ -493,6 +499,14 @@ impl<'a> AnyValue<'a> { ) } + pub fn is_nan(&self) -> bool { + match self { + AnyValue::Float32(f) => f.is_nan(), + AnyValue::Float64(f) => f.is_nan(), + _ => false, + } + } + pub fn is_null(&self) -> bool { matches!(self, AnyValue::Null) } @@ -537,7 +551,13 @@ impl<'a> AnyValue<'a> { // to string (av, DataType::String) => { - AnyValue::StringOwned(format_smartstring!("{}", av.extract::()?)) + if av.is_unsigned_integer() { + AnyValue::StringOwned(format_pl_smallstr!("{}", av.extract::()?)) + } else if av.is_float() { + AnyValue::StringOwned(format_pl_smallstr!("{}", av.extract::()?)) + } else { + AnyValue::StringOwned(format_pl_smallstr!("{}", av.extract::()?)) + } }, // to binary @@ -830,7 +850,23 @@ impl<'a> AnyValue<'a> { (UInt64(l), UInt64(r)) => UInt64(l + r), (Float32(l), Float32(r)) => Float32(l + r), (Float64(l), Float64(r)) => Float64(l + r), - _ => todo!(), + #[cfg(feature = "dtype-duration")] + (Duration(l, lu), Duration(r, ru)) => { + if lu != ru { + unimplemented!("adding durations with different units is not supported here"); + } + + Duration(l + r, *lu) + }, + #[cfg(feature = "dtype-decimal")] + (Decimal(l, ls), Decimal(r, rs)) => { + if ls != rs { + unimplemented!("adding decimals with different scales is not supported here"); + } + + Decimal(l + r, *ls) + }, + _ => unimplemented!(), } } @@ -838,7 +874,7 @@ impl<'a> AnyValue<'a> { pub fn as_borrowed(&self) -> AnyValue<'_> { match self { AnyValue::BinaryOwned(data) => AnyValue::Binary(data), - AnyValue::StringOwned(data) => AnyValue::String(data), + AnyValue::StringOwned(data) => AnyValue::String(data.as_str()), av => av.clone(), } } @@ -866,7 +902,7 @@ impl<'a> AnyValue<'a> { #[cfg(feature = "dtype-time")] Time(v) => Time(v), List(v) => List(v), - String(v) => StringOwned(v.into()), + String(v) => StringOwned(PlSmallStr::from_str(v)), StringOwned(v) => StringOwned(v), Binary(v) => BinaryOwned(v.to_vec()), BinaryOwned(v) => BinaryOwned(v), @@ -901,7 +937,7 @@ impl<'a> AnyValue<'a> { pub fn get_str(&self) -> Option<&str> { match self { AnyValue::String(s) => Some(s), - AnyValue::StringOwned(s) => Some(s), + AnyValue::StringOwned(s) => Some(s.as_str()), #[cfg(feature = "dtype-categorical")] AnyValue::Categorical(idx, rev, arr) | AnyValue::Enum(idx, rev, arr) => { let s = if arr.is_null() { @@ -934,6 +970,23 @@ impl AnyValue<'_> { pub fn eq_missing(&self, other: &Self, null_equal: bool) -> bool { use AnyValue::*; match (self, other) { + // Map to borrowed. + (StringOwned(l), r) => AnyValue::String(l.as_str()) == *r, + (BinaryOwned(l), r) => AnyValue::Binary(l.as_slice()) == *r, + #[cfg(feature = "object")] + (ObjectOwned(l), r) => AnyValue::Object(&*l.0) == *r, + (l, StringOwned(r)) => *l == AnyValue::String(r.as_str()), + (l, BinaryOwned(r)) => *l == AnyValue::Binary(r.as_slice()), + #[cfg(feature = "object")] + (l, ObjectOwned(r)) => *l == AnyValue::Object(&*r.0), + + // Comparison with null. + (Null, Null) => null_equal, + (Null, _) => false, + (_, Null) => false, + + // Equality between equal types. 
+ (Boolean(l), Boolean(r)) => *l == *r, (UInt8(l), UInt8(r)) => *l == *r, (UInt16(l), UInt16(r)) => *l == *r, (UInt32(l), UInt32(r)) => *l == *r, @@ -945,15 +998,7 @@ impl AnyValue<'_> { (Float32(l), Float32(r)) => l.to_total_ord() == r.to_total_ord(), (Float64(l), Float64(r)) => l.to_total_ord() == r.to_total_ord(), (String(l), String(r)) => l == r, - (String(l), StringOwned(r)) => l == r, - (StringOwned(l), String(r)) => l == r, - (StringOwned(l), StringOwned(r)) => l == r, - (Boolean(l), Boolean(r)) => *l == *r, (Binary(l), Binary(r)) => l == r, - (BinaryOwned(l), BinaryOwned(r)) => l == r, - (Binary(l), BinaryOwned(r)) => l == r, - (BinaryOwned(l), Binary(r)) => l == r, - (Null, Null) => null_equal, #[cfg(feature = "dtype-time")] (Time(l), Time(r)) => *l == *r, #[cfg(all(feature = "dtype-datetime", feature = "dtype-date"))] @@ -964,47 +1009,81 @@ impl AnyValue<'_> { }, (List(l), List(r)) => l == r, #[cfg(feature = "dtype-categorical")] - (Categorical(idx_l, rev_l, _), Categorical(idx_r, rev_r, _)) => match (rev_l, rev_r) { - (RevMapping::Global(_, _, id_l), RevMapping::Global(_, _, id_r)) => { - id_l == id_r && idx_l == idx_r - }, - (RevMapping::Local(_, id_l), RevMapping::Local(_, id_r)) => { - id_l == id_r && idx_l == idx_r - }, - _ => false, + (Categorical(idx_l, rev_l, ptr_l), Categorical(idx_r, rev_r, ptr_r)) => { + if !same_revmap(rev_l, *ptr_l, rev_r, *ptr_r) { + // We can't support this because our Hash impl directly hashes the index. If you + // add support for this we must change the Hash impl. + unimplemented!( + "comparing categoricals with different revmaps is not supported" + ); + } + + idx_l == idx_r }, #[cfg(feature = "dtype-categorical")] - (Enum(idx_l, _, _), Enum(idx_r, _, _)) => idx_l == idx_r, + (Enum(idx_l, rev_l, ptr_l), Enum(idx_r, rev_r, ptr_r)) => { + // We can't support this because our Hash impl directly hashes the index. If you + // add support for this we must change the Hash impl. + if !same_revmap(rev_l, *ptr_l, rev_r, *ptr_r) { + unimplemented!("comparing enums with different revmaps is not supported"); + } + + idx_l == idx_r + }, #[cfg(feature = "dtype-duration")] (Duration(l, tu_l), Duration(r, tu_r)) => l == r && tu_l == tu_r, #[cfg(feature = "dtype-struct")] (StructOwned(l), StructOwned(r)) => { - let l = &*l.0; - let r = &*r.0; - l == r + let l_av = &*l.0; + let r_av = &*r.0; + l_av == r_av }, - // TODO! add structowned with idx and arced structarray #[cfg(feature = "dtype-struct")] (StructOwned(l), Struct(idx, arr, fields)) => { - let fields_left = &*l.0; - let avs = struct_to_avs_static(*idx, arr, fields); - fields_left == avs + l.0.iter() + .eq_by_(struct_av_iter(*idx, arr, fields), |lv, rv| *lv == rv) }, #[cfg(feature = "dtype-struct")] (Struct(idx, arr, fields), StructOwned(r)) => { - let fields_right = &*r.0; - let avs = struct_to_avs_static(*idx, arr, fields); - fields_right == avs + struct_av_iter(*idx, arr, fields).eq_by_(r.0.iter(), |lv, rv| lv == *rv) + }, + #[cfg(feature = "dtype-struct")] + (Struct(l_idx, l_arr, l_fields), Struct(r_idx, r_arr, r_fields)) => { + struct_av_iter(*l_idx, l_arr, l_fields).eq(struct_av_iter(*r_idx, r_arr, r_fields)) }, #[cfg(feature = "dtype-decimal")] - (Decimal(v_l, scale_l), Decimal(v_r, scale_r)) => { - // Decimal equality here requires that both value and scale be equal, eg - // 1.2 at scale 1, and 1.20 at scale 2, are not equal. 
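// The old rule (removed here) required both value and scale to match. The replacement below
// rescales the lower-scale operand instead, so e.g. Decimal(12, 1) (= 1.2) and
// Decimal(120, 2) (= 1.20) now compare equal because 12 * 10^(2 - 1) = 120; if the rescaling
// overflows i128 (checked_pow/checked_mul returns None), the values are treated as unequal.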
- *v_l == *v_r && *scale_l == *scale_r + (Decimal(l_v, l_s), Decimal(r_v, r_s)) => { + // l_v / 10**l_s == r_v / 10**r_s + if l_s == r_s && l_v == r_v || *l_v == 0 && *r_v == 0 { + true + } else if l_s < r_s { + // l_v * 10**(r_s - l_s) == r_v + if let Some(lhs) = (|| { + let exp = i128::checked_pow(10, (r_s - l_s).try_into().ok()?)?; + l_v.checked_mul(exp) + })() { + lhs == *r_v + } else { + false + } + } else { + // l_v == r_v * 10**(l_s - r_s) + if let Some(rhs) = (|| { + let exp = i128::checked_pow(10, (l_s - r_s).try_into().ok()?)?; + r_v.checked_mul(exp) + })() { + *l_v == rhs + } else { + false + } + } }, #[cfg(feature = "object")] (Object(l), Object(r)) => l == r, - _ => false, + + (_, _) => { + unimplemented!("ordering for mixed dtypes is not supported") + }, } } } @@ -1021,6 +1100,23 @@ impl PartialOrd for AnyValue<'_> { fn partial_cmp(&self, other: &Self) -> Option { use AnyValue::*; match (self, &other) { + // Map to borrowed. + (StringOwned(l), r) => AnyValue::String(l.as_str()).partial_cmp(r), + (BinaryOwned(l), r) => AnyValue::Binary(l.as_slice()).partial_cmp(r), + #[cfg(feature = "object")] + (ObjectOwned(l), r) => AnyValue::Object(&*l.0).partial_cmp(r), + (l, StringOwned(r)) => l.partial_cmp(&AnyValue::String(r.as_str())), + (l, BinaryOwned(r)) => l.partial_cmp(&AnyValue::Binary(r.as_slice())), + #[cfg(feature = "object")] + (l, ObjectOwned(r)) => l.partial_cmp(&AnyValue::Object(&*r.0)), + + // Comparison with null. + (Null, Null) => Some(Ordering::Equal), + (Null, _) => Some(Ordering::Less), + (_, Null) => Some(Ordering::Greater), + + // Comparison between equal types. + (Boolean(l), Boolean(r)) => l.partial_cmp(r), (UInt8(l), UInt8(r)) => l.partial_cmp(r), (UInt16(l), UInt16(r)) => l.partial_cmp(r), (UInt32(l), UInt32(r)) => l.partial_cmp(r), @@ -1029,12 +1125,90 @@ impl PartialOrd for AnyValue<'_> { (Int16(l), Int16(r)) => l.partial_cmp(r), (Int32(l), Int32(r)) => l.partial_cmp(r), (Int64(l), Int64(r)) => l.partial_cmp(r), - (Float32(l), Float32(r)) => l.to_total_ord().partial_cmp(&r.to_total_ord()), - (Float64(l), Float64(r)) => l.to_total_ord().partial_cmp(&r.to_total_ord()), - _ => match (self.as_borrowed(), other.as_borrowed()) { - (String(l), String(r)) => l.partial_cmp(r), - (Binary(l), Binary(r)) => l.partial_cmp(r), - _ => None, + (Float32(l), Float32(r)) => Some(l.tot_cmp(r)), + (Float64(l), Float64(r)) => Some(l.tot_cmp(r)), + (String(l), String(r)) => l.partial_cmp(r), + (Binary(l), Binary(r)) => l.partial_cmp(r), + #[cfg(feature = "dtype-date")] + (Date(l), Date(r)) => l.partial_cmp(r), + #[cfg(feature = "dtype-datetime")] + (Datetime(lt, lu, lz), Datetime(rt, ru, rz)) => { + if lu != ru || lz != rz { + unimplemented!( + "comparing datetimes with different units or timezones is not supported" + ); + } + + lt.partial_cmp(rt) + }, + #[cfg(feature = "dtype-duration")] + (Duration(lt, lu), Duration(rt, ru)) => { + if lu != ru { + unimplemented!("comparing durations with different units is not supported"); + } + + lt.partial_cmp(rt) + }, + #[cfg(feature = "dtype-time")] + (Time(l), Time(r)) => l.partial_cmp(r), + #[cfg(feature = "dtype-categorical")] + (Categorical(..), Categorical(..)) => { + unimplemented!( + "can't order categoricals as AnyValues, dtype for ordering is needed" + ) + }, + #[cfg(feature = "dtype-categorical")] + (Enum(..), Enum(..)) => { + unimplemented!("can't order enums as AnyValues, dtype for ordering is needed") + }, + (List(_), List(_)) => { + unimplemented!("ordering for List dtype is not supported") + }, + #[cfg(feature = "dtype-array")] 
+ (Array(..), Array(..)) => { + unimplemented!("ordering for Array dtype is not supported") + }, + #[cfg(feature = "object")] + (Object(_), Object(_)) => { + unimplemented!("ordering for Object dtype is not supported") + }, + #[cfg(feature = "dtype-struct")] + (StructOwned(_), StructOwned(_)) + | (StructOwned(_), Struct(..)) + | (Struct(..), StructOwned(_)) + | (Struct(..), Struct(..)) => { + unimplemented!("ordering for Struct dtype is not supported") + }, + #[cfg(feature = "dtype-decimal")] + (Decimal(l_v, l_s), Decimal(r_v, r_s)) => { + // l_v / 10**l_s <=> r_v / 10**r_s + if l_s == r_s && l_v == r_v || *l_v == 0 && *r_v == 0 { + Some(Ordering::Equal) + } else if l_s < r_s { + // l_v * 10**(r_s - l_s) <=> r_v + if let Some(lhs) = (|| { + let exp = i128::checked_pow(10, (r_s - l_s).try_into().ok()?)?; + l_v.checked_mul(exp) + })() { + lhs.partial_cmp(r_v) + } else { + Some(Ordering::Greater) + } + } else { + // l_v <=> r_v * 10**(l_s - r_s) + if let Some(rhs) = (|| { + let exp = i128::checked_pow(10, (l_s - r_s).try_into().ok()?)?; + r_v.checked_mul(exp) + })() { + l_v.partial_cmp(&rhs) + } else { + Some(Ordering::Less) + } + } + }, + + (_, _) => { + unimplemented!("ordering for mixed dtypes is not supported") }, } } @@ -1063,6 +1237,38 @@ fn struct_to_avs_static(idx: usize, arr: &StructArray, fields: &[Field]) -> Vec< avs } +#[cfg(feature = "dtype-categorical")] +fn same_revmap( + rev_l: &RevMapping, + ptr_l: SyncPtr, + rev_r: &RevMapping, + ptr_r: SyncPtr, +) -> bool { + if ptr_l.is_null() && ptr_r.is_null() { + match (rev_l, rev_r) { + (RevMapping::Global(_, _, id_l), RevMapping::Global(_, _, id_r)) => id_l == id_r, + (RevMapping::Local(_, id_l), RevMapping::Local(_, id_r)) => id_l == id_r, + _ => false, + } + } else { + ptr_l == ptr_r + } +} + +#[cfg(feature = "dtype-struct")] +fn struct_av_iter<'a>( + idx: usize, + arr: &'a StructArray, + fields: &'a [Field], +) -> impl Iterator> { + let arrs = arr.values(); + (0..arrs.len()).map(move |i| unsafe { + let arr = &**arrs.get_unchecked_release(i); + let field = fields.get_unchecked_release(i); + arr_to_any_value(arr, idx, &field.dtype) + }) +} + pub trait GetAnyValue { /// # Safety /// @@ -1073,7 +1279,7 @@ pub trait GetAnyValue { impl GetAnyValue for ArrayRef { // Should only be called with physical types unsafe fn get_unchecked(&self, index: usize) -> AnyValue { - match self.data_type() { + match self.dtype() { ArrowDataType::Int8 => { let arr = self .as_any() @@ -1296,7 +1502,7 @@ mod test { DataType::Datetime(TimeUnit::Milliseconds, None), ), ( - ArrowDataType::Timestamp(ArrowTimeUnit::Second, Some("".to_string())), + ArrowDataType::Timestamp(ArrowTimeUnit::Second, Some(PlSmallStr::EMPTY)), DataType::Datetime(TimeUnit::Milliseconds, None), ), (ArrowDataType::LargeUtf8, DataType::String), @@ -1331,7 +1537,7 @@ mod test { (ArrowDataType::Time32(ArrowTimeUnit::Second), DataType::Time), ( ArrowDataType::List(Box::new(ArrowField::new( - "item", + PlSmallStr::from_static("item"), ArrowDataType::Float64, true, ))), @@ -1339,7 +1545,7 @@ mod test { ), ( ArrowDataType::LargeList(Box::new(ArrowField::new( - "item", + PlSmallStr::from_static("item"), ArrowDataType::Float64, true, ))), diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs index e092b43e7d67..9cdc5620ed07 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -8,7 +8,7 @@ use super::*; use crate::chunked_array::object::registry::ObjectRegistry; use crate::utils::materialize_dyn_int; 
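The Decimal arms above compare l_v / 10^l_s against r_v / 10^r_s without dividing: the operand with the smaller scale is multiplied up by 10^(scale difference) using checked arithmetic, so overflow is detected instead of silently wrapping. A minimal standalone sketch of that rescaling idea (simplified names and overflow handling, not the exact code from this diff):

fn decimal_eq(l_v: i128, l_s: usize, r_v: i128, r_s: usize) -> bool {
    // Both operands represent value / 10^scale.
    if l_v == 0 && r_v == 0 {
        return true;
    }
    if l_s == r_s {
        return l_v == r_v;
    }
    // Rescale the side with the smaller scale up to the larger one.
    let (small_v, big_v, diff) = if l_s < r_s {
        (l_v, r_v, r_s - l_s)
    } else {
        (r_v, l_v, l_s - r_s)
    };
    match u32::try_from(diff)
        .ok()
        .and_then(|d| 10i128.checked_pow(d))
        .and_then(|exp| small_v.checked_mul(exp))
    {
        Some(rescaled) => rescaled == big_v,
        // Overflow: the rescaled operand leaves the representable range, so
        // it cannot equal the other (in-range) operand.
        None => false,
    }
}

fn main() {
    assert!(decimal_eq(150, 2, 1500, 3)); // 1.50 == 1.500
    assert!(!decimal_eq(151, 2, 1500, 3)); // 1.51 != 1.500
}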
-pub type TimeZone = String; +pub type TimeZone = PlSmallStr; pub static DTYPE_ENUM_KEY: &str = "POLARS.CATEGORICAL_TYPE"; pub static DTYPE_ENUM_VALUE: &str = "ENUM"; @@ -152,14 +152,13 @@ impl Eq for DataType {} impl DataType { /// Standardize timezones to consistent values. - pub(crate) fn canonical_timezone(tz: &Option) -> Option { + pub(crate) fn canonical_timezone(tz: &Option) -> Option { match tz.as_deref() { - Some("") => None, + Some("") | None => None, #[cfg(feature = "timezones")] - Some("+00:00") | Some("00:00") | Some("utc") => Some("UTC"), - _ => tz.as_deref(), + Some("+00:00") | Some("00:00") | Some("utc") => Some(PlSmallStr::from_static("UTC")), + Some(v) => Some(PlSmallStr::from_str(v)), } - .map(|s| s.to_string()) } pub fn value_within_range(&self, other: AnyValue) -> bool { @@ -262,7 +261,7 @@ impl DataType { Struct(fields) => { let new_fields = fields .iter() - .map(|s| Field::new(s.name(), s.data_type().to_physical())) + .map(|s| Field::new(s.name().clone(), s.dtype().to_physical())) .collect(); Struct(new_fields) }, @@ -338,6 +337,10 @@ impl DataType { matches!(self, DataType::Binary) } + pub fn is_date(&self) -> bool { + matches!(self, DataType::Date) + } + pub fn is_object(&self) -> bool { #[cfg(feature = "object")] { @@ -498,7 +501,7 @@ impl DataType { } /// Convert to an Arrow Field - pub fn to_arrow_field(&self, name: &str, compat_level: CompatLevel) -> ArrowField { + pub fn to_arrow_field(&self, name: PlSmallStr, compat_level: CompatLevel) -> ArrowField { let metadata = match self { #[cfg(feature = "dtype-categorical")] DataType::Enum(_, _) => Some(BTreeMap::from([( @@ -506,8 +509,8 @@ impl DataType { DTYPE_ENUM_VALUE.into(), )])), DataType::BinaryOffset => Some(BTreeMap::from([( - "pl".to_string(), - "maintain_type".to_string(), + PlSmallStr::from_static("pl"), + PlSmallStr::from_static("maintain_type"), )])), _ => None, }; @@ -574,11 +577,11 @@ impl DataType { Time => Ok(ArrowDataType::Time64(ArrowTimeUnit::Nanosecond)), #[cfg(feature = "dtype-array")] Array(dt, size) => Ok(ArrowDataType::FixedSizeList( - Box::new(dt.to_arrow_field("item", compat_level)), + Box::new(dt.to_arrow_field(PlSmallStr::from_static("item"), compat_level)), *size, )), List(dt) => Ok(ArrowDataType::LargeList(Box::new( - dt.to_arrow_field("item", compat_level), + dt.to_arrow_field(PlSmallStr::from_static("item"), compat_level), ))), Null => Ok(ArrowDataType::Null), #[cfg(feature = "object")] @@ -788,7 +791,7 @@ pub fn merge_dtypes(left: &DataType, right: &DataType) -> PolarsResult } #[cfg(feature = "dtype-categorical")] -pub fn create_enum_data_type(categories: Utf8ViewArray) -> DataType { +pub fn create_enum_dtype(categories: Utf8ViewArray) -> DataType { let rev_map = RevMapping::build_local(categories); DataType::Enum(Some(Arc::new(rev_map)), Default::default()) } diff --git a/crates/polars-core/src/datatypes/field.rs b/crates/polars-core/src/datatypes/field.rs index aea148546ef0..f3bc3571505c 100644 --- a/crates/polars-core/src/datatypes/field.rs +++ b/crates/polars-core/src/datatypes/field.rs @@ -1,4 +1,4 @@ -use smartstring::alias::String as SmartString; +use polars_utils::pl_str::PlSmallStr; use super::*; @@ -9,10 +9,16 @@ use super::*; derive(Serialize, Deserialize) )] pub struct Field { - pub name: SmartString, + pub name: PlSmallStr, pub dtype: DataType, } +impl From for (PlSmallStr, DataType) { + fn from(value: Field) -> Self { + (value.name, value.dtype) + } +} + pub type FieldRef = Arc; impl Field { @@ -22,19 +28,12 @@ impl Field { /// /// ```rust /// # use 
polars_core::prelude::*; - /// let f1 = Field::new("Fruit name", DataType::String); - /// let f2 = Field::new("Lawful", DataType::Boolean); - /// let f2 = Field::new("Departure", DataType::Time); + /// let f1 = Field::new("Fruit name".into(), DataType::String); + /// let f2 = Field::new("Lawful".into(), DataType::Boolean); + /// let f2 = Field::new("Departure".into(), DataType::Time); /// ``` #[inline] - pub fn new(name: &str, dtype: DataType) -> Self { - Field { - name: name.into(), - dtype, - } - } - - pub fn from_owned(name: SmartString, dtype: DataType) -> Self { + pub fn new(name: PlSmallStr, dtype: DataType) -> Self { Field { name, dtype } } @@ -44,12 +43,12 @@ impl Field { /// /// ```rust /// # use polars_core::prelude::*; - /// let f = Field::new("Year", DataType::Int32); + /// let f = Field::new("Year".into(), DataType::Int32); /// /// assert_eq!(f.name(), "Year"); /// ``` #[inline] - pub fn name(&self) -> &SmartString { + pub fn name(&self) -> &PlSmallStr { &self.name } @@ -59,12 +58,12 @@ impl Field { /// /// ```rust /// # use polars_core::prelude::*; - /// let f = Field::new("Birthday", DataType::Date); + /// let f = Field::new("Birthday".into(), DataType::Date); /// - /// assert_eq!(f.data_type(), &DataType::Date); + /// assert_eq!(f.dtype(), &DataType::Date); /// ``` #[inline] - pub fn data_type(&self) -> &DataType { + pub fn dtype(&self) -> &DataType { &self.dtype } @@ -74,10 +73,10 @@ impl Field { /// /// ```rust /// # use polars_core::prelude::*; - /// let mut f = Field::new("Temperature", DataType::Int32); + /// let mut f = Field::new("Temperature".into(), DataType::Int32); /// f.coerce(DataType::Float32); /// - /// assert_eq!(f, Field::new("Temperature", DataType::Float32)); + /// assert_eq!(f, Field::new("Temperature".into(), DataType::Float32)); /// ``` pub fn coerce(&mut self, dtype: DataType) { self.dtype = dtype; @@ -89,12 +88,12 @@ impl Field { /// /// ```rust /// # use polars_core::prelude::*; - /// let mut f = Field::new("Atomic number", DataType::UInt32); + /// let mut f = Field::new("Atomic number".into(), DataType::UInt32); /// f.set_name("Proton".into()); /// - /// assert_eq!(f, Field::new("Proton", DataType::UInt32)); + /// assert_eq!(f, Field::new("Proton".into(), DataType::UInt32)); /// ``` - pub fn set_name(&mut self, name: SmartString) { + pub fn set_name(&mut self, name: PlSmallStr) { self.name = name; } @@ -104,13 +103,13 @@ impl Field { /// /// ```rust /// # use polars_core::prelude::*; - /// let f = Field::new("Value", DataType::Int64); - /// let af = arrow::datatypes::Field::new("Value", arrow::datatypes::ArrowDataType::Int64, true); + /// let f = Field::new("Value".into(), DataType::Int64); + /// let af = arrow::datatypes::Field::new("Value".into(), arrow::datatypes::ArrowDataType::Int64, true); /// /// assert_eq!(f.to_arrow(CompatLevel::newest()), af); /// ``` pub fn to_arrow(&self, compat_level: CompatLevel) -> ArrowField { - self.dtype.to_arrow_field(self.name.as_str(), compat_level) + self.dtype.to_arrow_field(self.name.clone(), compat_level) } } @@ -146,8 +145,8 @@ impl DataType { ArrowDataType::Float32 => DataType::Float32, ArrowDataType::Float64 => DataType::Float64, #[cfg(feature = "dtype-array")] - ArrowDataType::FixedSizeList(f, size) => DataType::Array(DataType::from_arrow(f.data_type(), bin_to_view).boxed(), *size), - ArrowDataType::LargeList(f) | ArrowDataType::List(f) => DataType::List(DataType::from_arrow(f.data_type(), bin_to_view).boxed()), + ArrowDataType::FixedSizeList(f, size) => DataType::Array(DataType::from_arrow(f.dtype(), 
bin_to_view).boxed(), *size), + ArrowDataType::LargeList(f) | ArrowDataType::List(f) => DataType::List(DataType::from_arrow(f.dtype(), bin_to_view).boxed()), ArrowDataType::Date32 => DataType::Date, ArrowDataType::Timestamp(tu, tz) => DataType::Datetime(tu.into(), DataType::canonical_timezone(tz)), ArrowDataType::Duration(tu) => DataType::Duration(tu.into()), @@ -163,7 +162,7 @@ impl DataType { ArrowDataType::Struct(_) => { panic!("activate the 'dtype-struct' feature to handle struct data types") } - ArrowDataType::Extension(name, _, _) if name == "POLARS_EXTENSION_TYPE" => { + ArrowDataType::Extension(name, _, _) if name.as_str() == "POLARS_EXTENSION_TYPE" => { #[cfg(feature = "object")] { DataType::Object("extension", None) @@ -199,6 +198,6 @@ impl From<&ArrowDataType> for DataType { impl From<&ArrowField> for Field { fn from(f: &ArrowField) -> Self { - Field::new(&f.name, f.data_type().into()) + Field::new(f.name.clone(), f.dtype().into()) } } diff --git a/crates/polars-core/src/datatypes/time_unit.rs b/crates/polars-core/src/datatypes/time_unit.rs index 481de22249b1..d3a9a61443fb 100644 --- a/crates/polars-core/src/datatypes/time_unit.rs +++ b/crates/polars-core/src/datatypes/time_unit.rs @@ -58,7 +58,7 @@ impl TimeUnit { } } -#[cfg(feature = "rows")] +#[cfg(any(feature = "rows", feature = "object"))] #[cfg(any(feature = "dtype-datetime", feature = "dtype-duration"))] #[inline] pub(crate) fn convert_time_units(v: i64, tu_l: TimeUnit, tu_r: TimeUnit) -> i64 { diff --git a/crates/polars-core/src/export.rs b/crates/polars-core/src/export.rs index d8446d2e9173..1733715cc35d 100644 --- a/crates/polars-core/src/export.rs +++ b/crates/polars-core/src/export.rs @@ -4,6 +4,6 @@ pub use chrono; pub use regex; #[cfg(feature = "serde")] pub use serde; -pub use {ahash, arrow, num_traits as num, once_cell, rayon}; +pub use {arrow, num_traits as num, once_cell, rayon}; pub use crate::hashing::_boost_hash_combine; diff --git a/crates/polars-core/src/fmt.rs b/crates/polars-core/src/fmt.rs index 066536ddec15..00455a1a841a 100644 --- a/crates/polars-core/src/fmt.rs +++ b/crates/polars-core/src/fmt.rs @@ -446,16 +446,16 @@ fn field_to_str(f: &Field, str_truncate: usize) -> (String, usize) { if env_is_true(FMT_TABLE_HIDE_COLUMN_NAMES) { column_name = "".to_string(); } - let column_data_type = if env_is_true(FMT_TABLE_HIDE_COLUMN_DATA_TYPES) { + let column_dtype = if env_is_true(FMT_TABLE_HIDE_COLUMN_DATA_TYPES) { "".to_string() } else if env_is_true(FMT_TABLE_INLINE_COLUMN_DATA_TYPE) | env_is_true(FMT_TABLE_HIDE_COLUMN_NAMES) { - format!("{}", f.data_type()) + format!("{}", f.dtype()) } else { - format!("\n{}", f.data_type()) + format!("\n{}", f.dtype()) }; - let mut dtype_length = column_data_type.trim_start().len(); + let mut dtype_length = column_dtype.trim_start().len(); let mut separator = "\n---"; if env_is_true(FMT_TABLE_HIDE_COLUMN_SEPARATOR) | env_is_true(FMT_TABLE_HIDE_COLUMN_NAMES) @@ -466,11 +466,11 @@ fn field_to_str(f: &Field, str_truncate: usize) -> (String, usize) { let s = if env_is_true(FMT_TABLE_INLINE_COLUMN_DATA_TYPE) & !env_is_true(FMT_TABLE_HIDE_COLUMN_DATA_TYPES) { - let inline_name_dtype = format!("{column_name} ({column_data_type})"); + let inline_name_dtype = format!("{column_name} ({column_dtype})"); dtype_length = inline_name_dtype.len(); inline_name_dtype } else { - format!("{column_name}{separator}{column_data_type}") + format!("{column_name}{separator}{column_dtype}") }; let mut s_len = std::cmp::max(name_length, dtype_length); let separator_length = separator.trim().len(); 
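Since `from_arrow` above routes Arrow timestamp timezones through `canonical_timezone`, it helps to spell out what that normalization does: an empty or missing timezone becomes `None`, and the common UTC spellings collapse to a single `"UTC"` value (the latter only with the `timezones` feature enabled). A rough standalone illustration using plain `String` instead of `PlSmallStr`:

fn canonical_timezone(tz: Option<&str>) -> Option<String> {
    match tz {
        // Missing or empty timezone: treat as no timezone at all.
        None | Some("") => None,
        // Common spellings of UTC are normalized to one canonical value.
        Some("+00:00") | Some("00:00") | Some("utc") => Some("UTC".to_string()),
        // Everything else is kept as-is.
        Some(v) => Some(v.to_string()),
    }
}

fn main() {
    assert_eq!(canonical_timezone(Some("")), None);
    assert_eq!(canonical_timezone(Some("+00:00")), Some("UTC".to_string()));
    assert_eq!(
        canonical_timezone(Some("Europe/Amsterdam")),
        Some("Europe/Amsterdam".to_string())
    );
}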
@@ -729,7 +729,7 @@ impl Display for DataFrame { let num_preset = std::env::var(FMT_TABLE_CELL_NUMERIC_ALIGNMENT) .unwrap_or_else(|_| str_preset.to_string()); for (column_index, column) in table.column_iter_mut().enumerate() { - let dtype = fields[column_index].data_type(); + let dtype = fields[column_index].dtype(); let mut preset = str_preset.as_str(); if dtype.is_numeric() || dtype.is_decimal() { preset = num_preset.as_str(); @@ -1190,8 +1190,12 @@ mod test { #[test] fn test_fmt_list() { - let mut builder = - ListPrimitiveChunkedBuilder::::new("a", 10, 10, DataType::Int32); + let mut builder = ListPrimitiveChunkedBuilder::::new( + PlSmallStr::from_static("a"), + 10, + 10, + DataType::Int32, + ); builder.append_opt_slice(Some(&[1, 2, 3, 4, 5, 6])); builder.append_opt_slice(None); let list_long = builder.finish().into_series(); @@ -1266,8 +1270,12 @@ Series: 'a' [list[i32]] format!("{:?}", list_long) ); - let mut builder = - ListPrimitiveChunkedBuilder::::new("a", 10, 10, DataType::Int32); + let mut builder = ListPrimitiveChunkedBuilder::::new( + PlSmallStr::from_static("a"), + 10, + 10, + DataType::Int32, + ); builder.append_opt_slice(Some(&[1])); builder.append_opt_slice(None); let list_short = builder.finish().into_series(); @@ -1308,8 +1316,12 @@ Series: 'a' [list[i32]] format!("{:?}", list_short) ); - let mut builder = - ListPrimitiveChunkedBuilder::::new("a", 10, 10, DataType::Int32); + let mut builder = ListPrimitiveChunkedBuilder::::new( + PlSmallStr::from_static("a"), + 10, + 10, + DataType::Int32, + ); builder.append_opt_slice(Some(&[])); builder.append_opt_slice(None); let list_empty = builder.finish().into_series(); @@ -1329,7 +1341,8 @@ Series: 'a' [list[i32]] #[test] fn test_fmt_temporal() { - let s = Int32Chunked::new("Date", &[Some(1), None, Some(3)]).into_date(); + let s = Int32Chunked::new(PlSmallStr::from_static("Date"), &[Some(1), None, Some(3)]) + .into_date(); assert_eq!( r#"shape: (3,) Series: 'Date' [date] @@ -1341,7 +1354,7 @@ Series: 'Date' [date] format!("{:?}", s.into_series()) ); - let s = Int64Chunked::new("", &[Some(1), None, Some(1_000_000_000_000)]) + let s = Int64Chunked::new(PlSmallStr::EMPTY, &[Some(1), None, Some(1_000_000_000_000)]) .into_datetime(TimeUnit::Nanoseconds, None); assert_eq!( r#"shape: (3,) @@ -1357,7 +1370,7 @@ Series: '' [datetime[ns]] #[test] fn test_fmt_chunkedarray() { - let ca = Int32Chunked::new("Date", &[Some(1), None, Some(3)]); + let ca = Int32Chunked::new(PlSmallStr::from_static("Date"), &[Some(1), None, Some(3)]); assert_eq!( r#"shape: (3,) ChunkedArray: 'Date' [i32] @@ -1368,7 +1381,7 @@ ChunkedArray: 'Date' [i32] ]"#, format!("{:?}", ca) ); - let ca = StringChunked::new("name", &["a", "b"]); + let ca = StringChunked::new(PlSmallStr::from_static("name"), &["a", "b"]); assert_eq!( r#"shape: (2,) ChunkedArray: 'name' [str] diff --git a/crates/polars-core/src/frame/arithmetic.rs b/crates/polars-core/src/frame/arithmetic.rs index 887fedfb2d57..69e2279cd47f 100644 --- a/crates/polars-core/src/frame/arithmetic.rs +++ b/crates/polars-core/src/frame/arithmetic.rs @@ -151,7 +151,7 @@ impl DataFrame { // trick to fill a series with nulls let vals: &[Option] = &[None]; - let s = Series::new(name, vals).cast(dtype)?; + let s = Series::new(name.clone(), vals).cast(dtype)?; cols.push(s.new_from_index(0, max_len)) } } diff --git a/crates/polars-core/src/frame/chunks.rs b/crates/polars-core/src/frame/chunks.rs index 2a371e85ab31..349a77c56d75 100644 --- a/crates/polars-core/src/frame/chunks.rs +++ b/crates/polars-core/src/frame/chunks.rs @@ 
-5,15 +5,15 @@ use crate::prelude::*; use crate::utils::_split_offsets; use crate::POOL; -impl TryFrom<(RecordBatch, &[ArrowField])> for DataFrame { +impl TryFrom<(RecordBatch, &ArrowSchema)> for DataFrame { type Error = PolarsError; - fn try_from(arg: (RecordBatch, &[ArrowField])) -> PolarsResult { + fn try_from(arg: (RecordBatch, &ArrowSchema)) -> PolarsResult { let columns: PolarsResult> = arg .0 .columns() .iter() - .zip(arg.1) + .zip(arg.1.iter_values()) .map(|(arr, field)| Series::try_from((field, arr.clone()))) .collect(); diff --git a/crates/polars-core/src/frame/explode.rs b/crates/polars-core/src/frame/explode.rs index d3b28dd8b830..3e597756eb1e 100644 --- a/crates/polars-core/src/frame/explode.rs +++ b/crates/polars-core/src/frame/explode.rs @@ -1,14 +1,12 @@ -use arrow::legacy::kernels::concatenate::concatenate_owned_unchecked; use arrow::offset::OffsetsBuffer; +use polars_utils::pl_str::PlSmallStr; use rayon::prelude::*; -#[cfg(feature = "serde-lazy")] +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -use smartstring::alias::String as SmartString; use crate::chunked_array::ops::explode::offsets_to_indexes; use crate::prelude::*; use crate::series::IsSorted; -use crate::utils::try_get_supertype; use crate::POOL; fn get_exploded(series: &Series) -> PolarsResult<(Series, OffsetsBuffer)> { @@ -22,16 +20,12 @@ fn get_exploded(series: &Series) -> PolarsResult<(Series, OffsetsBuffer)> { /// Arguments for `[DataFrame::unpivot]` function #[derive(Clone, Default, Debug, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "serde-lazy", derive(Serialize, Deserialize))] -pub struct UnpivotArgs { - pub on: Vec, - pub index: Vec, - pub variable_name: Option, - pub value_name: Option, - /// Whether the unpivot may be done - /// in the streaming engine - /// This will not have a stable ordering - pub streamable: bool, +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct UnpivotArgsIR { + pub on: Vec, + pub index: Vec, + pub variable_name: Option, + pub value_name: Option, } impl DataFrame { @@ -45,15 +39,19 @@ impl DataFrame { return Ok(df); } columns.sort_by(|sa, sb| { - self.check_name_to_idx(sa.name()) + self.check_name_to_idx(sa.name().as_str()) .expect("checked above") - .partial_cmp(&self.check_name_to_idx(sb.name()).expect("checked above")) + .partial_cmp( + &self + .check_name_to_idx(sb.name().as_str()) + .expect("checked above"), + ) .expect("cmp usize -> Ordering") }); // first remove all the exploded columns for s in &columns { - df = df.drop(s.name())?; + df = df.drop(s.name().as_str())?; } let exploded_columns = POOL.install(|| { @@ -69,7 +67,7 @@ impl DataFrame { exploded: Series, ) -> PolarsResult<()> { if exploded.len() == df.height() || df.width() == 0 { - let col_idx = original_df.check_name_to_idx(exploded.name())?; + let col_idx = original_df.check_name_to_idx(exploded.name().as_str())?; df.columns.insert(col_idx, exploded); } else { polars_bail!( @@ -104,7 +102,7 @@ impl DataFrame { let (exploded, offsets) = &exploded_columns[0]; let row_idx = offsets_to_indexes(offsets.as_slice(), exploded.len()); - let mut row_idx = IdxCa::from_vec("", row_idx); + let mut row_idx = IdxCa::from_vec(PlSmallStr::EMPTY, row_idx); row_idx.set_sorted_flag(IsSorted::Ascending); // SAFETY: @@ -129,13 +127,13 @@ impl DataFrame { /// /// ```ignore /// # use polars_core::prelude::*; - /// let s0 = Series::new("a", &[1i64, 2, 3]); - /// let s1 = Series::new("b", &[1i64, 1, 1]); - /// let s2 = Series::new("c", &[2i64, 2, 2]); + /// let s0 = Series::new("a".into(), 
&[1i64, 2, 3]); + /// let s1 = Series::new("b".into(), &[1i64, 1, 1]); + /// let s2 = Series::new("c".into(), &[2i64, 2, 2]); /// let list = Series::new("foo", &[s0, s1, s2]); /// - /// let s0 = Series::new("B", [1, 2, 3]); - /// let s1 = Series::new("C", [1, 1, 1]); + /// let s0 = Series::new("B".into(), [1, 2, 3]); + /// let s1 = Series::new("C".into(), [1, 1, 1]); /// let df = DataFrame::new(vec![list, s0, s1])?; /// let exploded = df.explode(["foo"])?; /// @@ -185,182 +183,13 @@ impl DataFrame { pub fn explode(&self, columns: I) -> PolarsResult where I: IntoIterator, - S: AsRef, + S: Into, { // We need to sort the column by order of original occurrence. Otherwise the insert by index // below will panic let columns = self.select_series(columns)?; self.explode_impl(columns) } - - /// - /// Unpivot a `DataFrame` from wide to long format. - /// - /// # Example - /// - /// # Arguments - /// - /// * `on` - String slice that represent the columns to use as value variables. - /// * `index` - String slice that represent the columns to use as id variables. - /// - /// If `on` is empty all columns that are not in `index` will be used. - /// - /// ```ignore - /// # use polars_core::prelude::*; - /// let df = df!("A" => &["a", "b", "a"], - /// "B" => &[1, 3, 5], - /// "C" => &[10, 11, 12], - /// "D" => &[2, 4, 6] - /// )?; - /// - /// let unpivoted = df.unpivot(&["A", "B"], &["C", "D"])?; - /// println!("{:?}", df); - /// println!("{:?}", unpivoted); - /// # Ok::<(), PolarsError>(()) - /// ``` - /// Outputs: - /// ```text - /// +-----+-----+-----+-----+ - /// | A | B | C | D | - /// | --- | --- | --- | --- | - /// | str | i32 | i32 | i32 | - /// +=====+=====+=====+=====+ - /// | "a" | 1 | 10 | 2 | - /// +-----+-----+-----+-----+ - /// | "b" | 3 | 11 | 4 | - /// +-----+-----+-----+-----+ - /// | "a" | 5 | 12 | 6 | - /// +-----+-----+-----+-----+ - /// - /// +-----+-----+----------+-------+ - /// | A | B | variable | value | - /// | --- | --- | --- | --- | - /// | str | i32 | str | i32 | - /// +=====+=====+==========+=======+ - /// | "a" | 1 | "C" | 10 | - /// +-----+-----+----------+-------+ - /// | "b" | 3 | "C" | 11 | - /// +-----+-----+----------+-------+ - /// | "a" | 5 | "C" | 12 | - /// +-----+-----+----------+-------+ - /// | "a" | 1 | "D" | 2 | - /// +-----+-----+----------+-------+ - /// | "b" | 3 | "D" | 4 | - /// +-----+-----+----------+-------+ - /// | "a" | 5 | "D" | 6 | - /// +-----+-----+----------+-------+ - /// ``` - pub fn unpivot(&self, on: I, index: J) -> PolarsResult - where - I: IntoVec, - J: IntoVec, - { - let index = index.into_vec(); - let on = on.into_vec(); - self.unpivot2(UnpivotArgs { - on, - index, - ..Default::default() - }) - } - - /// Similar to unpivot, but without generics. This may be easier if you want to pass - /// an empty `index` or empty `on`. - pub fn unpivot2(&self, args: UnpivotArgs) -> PolarsResult { - let index = args.index; - let mut on = args.on; - - let variable_name = args.variable_name.as_deref().unwrap_or("variable"); - let value_name = args.value_name.as_deref().unwrap_or("value"); - - let len = self.height(); - - // if value vars is empty we take all columns that are not in id_vars. 
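For context on the `offsets_to_indexes` call in the explode path above: the list column is flattened into one long values buffer, and the remaining columns are then gathered by an index vector that repeats each original row number once per element of its list. A simplified sketch of that expansion (the real helper also takes a capacity argument and has to deal with empty lists):

// Assumes every list is non-empty; offsets[i]..offsets[i + 1] is row i's slice.
fn offsets_to_indexes(offsets: &[i64]) -> Vec<u32> {
    let mut idx = Vec::with_capacity(*offsets.last().unwrap_or(&0) as usize);
    for (row, win) in offsets.windows(2).enumerate() {
        let len = (win[1] - win[0]) as usize;
        // Repeat the original row number once per exploded element.
        idx.extend(std::iter::repeat(row as u32).take(len));
    }
    idx
}

fn main() {
    // Lists of lengths 3, 1 and 2 explode into six rows gathered from rows 0, 0, 0, 1, 2, 2.
    assert_eq!(offsets_to_indexes(&[0, 3, 4, 6]), vec![0, 0, 0, 1, 2, 2]);
}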
- if on.is_empty() { - // return empty frame if there are no columns available to use as value vars - if index.len() == self.width() { - let variable_col = Series::new_empty(variable_name, &DataType::String); - let value_col = Series::new_empty(variable_name, &DataType::Null); - - let mut out = self.select(index).unwrap().clear().columns; - out.push(variable_col); - out.push(value_col); - - return Ok(unsafe { DataFrame::new_no_checks(out) }); - } - - let index_set = PlHashSet::from_iter(index.iter().map(|s| s.as_str())); - on = self - .get_columns() - .iter() - .filter_map(|s| { - if index_set.contains(s.name()) { - None - } else { - Some(s.name().into()) - } - }) - .collect(); - } - - // values will all be placed in single column, so we must find their supertype - let schema = self.schema(); - let mut iter = on - .iter() - .map(|v| schema.get(v).ok_or_else(|| polars_err!(col_not_found = v))); - let mut st = iter.next().unwrap()?.clone(); - for dt in iter { - st = try_get_supertype(&st, dt?)?; - } - - // The column name of the variable that is unpivoted - let mut variable_col = MutablePlString::with_capacity(len * on.len() + 1); - // prepare ids - let ids_ = self.select_with_schema_unchecked(index, &schema)?; - let mut ids = ids_.clone(); - if ids.width() > 0 { - for _ in 0..on.len() - 1 { - ids.vstack_mut_unchecked(&ids_) - } - } - ids.as_single_chunk_par(); - drop(ids_); - - let mut values = Vec::with_capacity(on.len()); - - for value_column_name in &on { - variable_col.extend_constant(len, Some(value_column_name.as_str())); - // ensure we go via the schema so we are O(1) - // self.column() is linear - // together with this loop that would make it O^2 over `on` - let (pos, _name, _dtype) = schema.try_get_full(value_column_name)?; - let col = &self.columns[pos]; - let value_col = col.cast(&st).map_err( - |_| polars_err!(InvalidOperation: "'unpivot' not supported for dtype: {}", col.dtype()), - )?; - values.extend_from_slice(value_col.chunks()) - } - let values_arr = concatenate_owned_unchecked(&values)?; - // SAFETY: - // The give dtype is correct - let values = - unsafe { Series::from_chunks_and_dtype_unchecked(value_name, vec![values_arr], &st) }; - - let variable_col = variable_col.as_box(); - // SAFETY: - // The given dtype is correct - let variables = unsafe { - Series::from_chunks_and_dtype_unchecked( - variable_name, - vec![variable_col], - &DataType::String, - ) - }; - - ids.hstack_mut(&[variables, values])?; - - Ok(ids) - } } #[cfg(test)] @@ -371,13 +200,13 @@ mod test { #[cfg(feature = "dtype-i8")] #[cfg_attr(miri, ignore)] fn test_explode() { - let s0 = Series::new("a", &[1i8, 2, 3]); - let s1 = Series::new("b", &[1i8, 1, 1]); - let s2 = Series::new("c", &[2i8, 2, 2]); - let list = Series::new("foo", &[s0, s1, s2]); + let s0 = Series::new(PlSmallStr::from_static("a"), &[1i8, 2, 3]); + let s1 = Series::new(PlSmallStr::from_static("b"), &[1i8, 1, 1]); + let s2 = Series::new(PlSmallStr::from_static("c"), &[2i8, 2, 2]); + let list = Series::new(PlSmallStr::from_static("foo"), &[s0, s1, s2]); - let s0 = Series::new("B", [1, 2, 3]); - let s1 = Series::new("C", [1, 1, 1]); + let s0 = Series::new(PlSmallStr::from_static("B"), [1, 2, 3]); + let s1 = Series::new(PlSmallStr::from_static("C"), [1, 1, 1]); let df = DataFrame::new(vec![list, s0.clone(), s1.clone()]).unwrap(); let exploded = df.explode(["foo"]).unwrap(); assert_eq!(exploded.shape(), (9, 3)); @@ -392,11 +221,14 @@ mod test { #[test] #[cfg_attr(miri, ignore)] fn test_explode_df_empty_list() -> PolarsResult<()> { - let s0 = 
Series::new("a", &[1, 2, 3]); - let s1 = Series::new("b", &[1, 1, 1]); - let list = Series::new("foo", &[s0, s1.clone(), s1.clear()]); - let s0 = Series::new("B", [1, 2, 3]); - let s1 = Series::new("C", [1, 1, 1]); + let s0 = Series::new(PlSmallStr::from_static("a"), &[1, 2, 3]); + let s1 = Series::new(PlSmallStr::from_static("b"), &[1, 1, 1]); + let list = Series::new( + PlSmallStr::from_static("foo"), + &[s0, s1.clone(), s1.clear()], + ); + let s0 = Series::new(PlSmallStr::from_static("B"), [1, 2, 3]); + let s1 = Series::new(PlSmallStr::from_static("C"), [1, 1, 1]); let df = DataFrame::new(vec![list, s0.clone(), s1.clone()])?; let out = df.explode(["foo"])?; @@ -408,7 +240,10 @@ mod test { assert!(out.equals_missing(&expected)); - let list = Series::new("foo", [s0.clone(), s1.clear(), s1.clone()]); + let list = Series::new( + PlSmallStr::from_static("foo"), + [s0.clone(), s1.clear(), s1.clone()], + ); let df = DataFrame::new(vec![list, s0, s1])?; let out = df.explode(["foo"])?; let expected = df![ @@ -424,9 +259,9 @@ mod test { #[test] #[cfg_attr(miri, ignore)] fn test_explode_single_col() -> PolarsResult<()> { - let s0 = Series::new("a", &[1i32, 2, 3]); - let s1 = Series::new("b", &[1i32, 1, 1]); - let list = Series::new("foo", &[s0, s1]); + let s0 = Series::new(PlSmallStr::from_static("a"), &[1i32, 2, 3]); + let s1 = Series::new(PlSmallStr::from_static("b"), &[1i32, 1, 1]); + let list = Series::new(PlSmallStr::from_static("foo"), &[s0, s1]); let df = DataFrame::new(vec![list])?; let out = df.explode(["foo"])?; @@ -439,55 +274,4 @@ mod test { Ok(()) } - - #[test] - #[cfg_attr(miri, ignore)] - fn test_unpivot() -> PolarsResult<()> { - let df = df!("A" => &["a", "b", "a"], - "B" => &[1, 3, 5], - "C" => &[10, 11, 12], - "D" => &[2, 4, 6] - ) - .unwrap(); - - let unpivoted = df.unpivot(["C", "D"], ["A", "B"])?; - assert_eq!( - Vec::from(unpivoted.column("value")?.i32()?), - &[Some(10), Some(11), Some(12), Some(2), Some(4), Some(6)] - ); - - let args = UnpivotArgs { - on: vec![], - index: vec![], - ..Default::default() - }; - - let unpivoted = df.unpivot2(args).unwrap(); - let value = unpivoted.column("value")?; - // String because of supertype - let value = value.str()?; - let value = value.into_no_null_iter().collect::>(); - assert_eq!( - value, - &["a", "b", "a", "1", "3", "5", "10", "11", "12", "2", "4", "6"] - ); - - let args = UnpivotArgs { - on: vec![], - index: vec!["A".into()], - ..Default::default() - }; - - let unpivoted = df.unpivot2(args).unwrap(); - let value = unpivoted.column("value")?; - let value = value.i32()?; - let value = value.into_no_null_iter().collect::>(); - assert_eq!(value, &[1, 3, 5, 10, 11, 12, 2, 4, 6]); - let variable = unpivoted.column("variable")?; - let variable = variable.str()?; - let variable = variable.into_no_null_iter().collect::>(); - assert_eq!(variable, &["B", "B", "B", "C", "C", "C", "D", "D", "D"]); - assert!(unpivoted.column("A").is_ok()); - Ok(()) - } } diff --git a/crates/polars-core/src/frame/from.rs b/crates/polars-core/src/frame/from.rs index 607fab946857..5c3e1a8cb212 100644 --- a/crates/polars-core/src/frame/from.rs +++ b/crates/polars-core/src/frame/from.rs @@ -17,9 +17,9 @@ impl TryFrom for DataFrame { // reported data type is correct unsafe { Series::_try_from_arrow_unchecked_with_md( - &fld.name, + fld.name.clone(), vec![arr], - fld.data_type(), + fld.dtype(), Some(&fld.metadata), ) } diff --git a/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs b/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs index 
ef57f7bb953c..3e71953c5753 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs @@ -74,19 +74,19 @@ where list_values.into(), validity, ); - let data_type = ListArray::::default_datatype( + let dtype = ListArray::::default_datatype( T::get_dtype().to_arrow(CompatLevel::newest()), ); // SAFETY: // offsets are monotonically increasing let arr = ListArray::::new( - data_type, + dtype, Offsets::new_unchecked(offsets).into(), Box::new(array), None, ); - let mut ca = ListChunked::with_chunk(self.name(), arr); + let mut ca = ListChunked::with_chunk(self.name().clone(), arr); if can_fast_explode { ca.set_fast_explode() } @@ -139,16 +139,16 @@ where list_values.into(), validity, ); - let data_type = ListArray::::default_datatype( + let dtype = ListArray::::default_datatype( T::get_dtype().to_arrow(CompatLevel::newest()), ); let arr = ListArray::::new( - data_type, + dtype, Offsets::new_unchecked(offsets).into(), Box::new(array), None, ); - let mut ca = ListChunked::with_chunk(self.name(), arr); + let mut ca = ListChunked::with_chunk(self.name().clone(), arr); if can_fast_explode { ca.set_fast_explode() } @@ -162,14 +162,14 @@ impl AggList for NullChunked { unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { match groups { GroupsProxy::Idx(groups) => { - let mut builder = ListNullChunkedBuilder::new(self.name(), groups.len()); + let mut builder = ListNullChunkedBuilder::new(self.name().clone(), groups.len()); for idx in groups.all().iter() { builder.append_with_len(idx.len()); } builder.finish().into_series() }, GroupsProxy::Slice { groups, .. } => { - let mut builder = ListNullChunkedBuilder::new(self.name(), groups.len()); + let mut builder = ListNullChunkedBuilder::new(self.name().clone(), groups.len()); for [_, len] in groups { builder.append_with_len(*len as usize); } @@ -259,17 +259,17 @@ impl AggList for ObjectChunked { // the pointer does not fail. pe.set_to_series_fn::(); let extension_array = Box::new(pe.take_and_forget()) as ArrayRef; - let extension_dtype = extension_array.data_type(); + let extension_dtype = extension_array.dtype(); - let data_type = ListArray::::default_datatype(extension_dtype.clone()); + let dtype = ListArray::::default_datatype(extension_dtype.clone()); // SAFETY: offsets are monotonically increasing. 
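// Row i of the resulting list array covers child[offsets[i]..offsets[i + 1]] of the flattened values, hence the monotonicity requirement.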
let arr = ListArray::::new( - data_type, + dtype, Offsets::new_unchecked(offsets).into(), extension_array, None, ); - let mut listarr = ListChunked::with_chunk(self.name(), arr); + let mut listarr = ListChunked::with_chunk(self.name().clone(), arr); if can_fast_explode { listarr.set_fast_explode() } @@ -291,10 +291,12 @@ impl AggList for StructChunked { }; let arr = gathered.chunks()[0].clone(); - let dtype = LargeListArray::default_datatype(arr.data_type().clone()); + let dtype = LargeListArray::default_datatype(arr.dtype().clone()); - let mut chunk = - ListChunked::with_chunk(self.name(), LargeListArray::new(dtype, offsets, arr, None)); + let mut chunk = ListChunked::with_chunk( + self.name().clone(), + LargeListArray::new(dtype, offsets, arr, None), + ); chunk.set_dtype(DataType::List(Box::new(self.dtype().clone()))); if can_fast_explode { chunk.set_fast_explode() @@ -320,10 +322,12 @@ where }; let arr = gathered.chunks()[0].clone(); - let dtype = LargeListArray::default_datatype(arr.data_type().clone()); + let dtype = LargeListArray::default_datatype(arr.dtype().clone()); - let mut chunk = - ListChunked::with_chunk(ca.name(), LargeListArray::new(dtype, offsets, arr, None)); + let mut chunk = ListChunked::with_chunk( + ca.name().clone(), + LargeListArray::new(dtype, offsets, arr, None), + ); chunk.set_dtype(DataType::List(Box::new(ca.dtype().clone()))); if can_fast_explode { chunk.set_fast_explode() diff --git a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs index 447352a0faaf..fe71148cd49b 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs @@ -73,7 +73,7 @@ impl Series { } }, ) - .collect_ca(""); + .collect_ca(PlSmallStr::EMPTY); // SAFETY: groups are always in bounds. s.take_unchecked(&indices) }, @@ -81,7 +81,7 @@ impl Series { let indices = groups .iter() .map(|&[first, len]| if len == 0 { None } else { Some(first) }) - .collect_ca(""); + .collect_ca(PlSmallStr::EMPTY); // SAFETY: groups are always in bounds. s.take_unchecked(&indices) }, @@ -175,7 +175,7 @@ impl Series { * (MS_IN_DAY as f64)) .cast(&Datetime(TimeUnit::Milliseconds, None)) .unwrap(), - _ => Series::full_null("", groups.len(), s.dtype()), + _ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()), } } @@ -227,7 +227,7 @@ impl Series { * (MS_IN_DAY as f64)) .cast(&Datetime(TimeUnit::Milliseconds, None)) .unwrap(), - _ => Series::full_null("", groups.len(), s.dtype()), + _ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()), } } @@ -262,7 +262,7 @@ impl Series { s } }, - _ => Series::full_null("", groups.len(), s.dtype()), + _ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()), } } @@ -287,7 +287,7 @@ impl Series { Some(idx[idx.len() - 1]) } }) - .collect_ca(""); + .collect_ca(PlSmallStr::EMPTY); s.take_unchecked(&indices) }, GroupsProxy::Slice { groups, .. 
} => { @@ -300,7 +300,7 @@ impl Series { Some(first + len - 1) } }) - .collect_ca(""); + .collect_ca(PlSmallStr::EMPTY); s.take_unchecked(&indices) }, }; diff --git a/crates/polars-core/src/frame/group_by/aggregations/mod.rs b/crates/polars-core/src/frame/group_by/aggregations/mod.rs index bd805d4bd332..f30f60089750 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/mod.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/mod.rs @@ -359,7 +359,7 @@ where { let invalid_quantile = !(0.0..=1.0).contains(&quantile); if invalid_quantile { - return Series::full_null(ca.name(), groups.len(), ca.dtype()); + return Series::full_null(ca.name().clone(), groups.len(), ca.dtype()); } match groups { GroupsProxy::Idx(groups) => { diff --git a/crates/polars-core/src/frame/group_by/expr.rs b/crates/polars-core/src/frame/group_by/expr.rs index 160f9e81f8d3..f35a04a5664f 100644 --- a/crates/polars-core/src/frame/group_by/expr.rs +++ b/crates/polars-core/src/frame/group_by/expr.rs @@ -4,5 +4,5 @@ pub trait PhysicalAggExpr { #[allow(clippy::ptr_arg)] fn evaluate(&self, df: &DataFrame, groups: &GroupsProxy) -> PolarsResult; - fn root_name(&self) -> PolarsResult<&str>; + fn root_name(&self) -> PolarsResult<&PlSmallStr>; } diff --git a/crates/polars-core/src/frame/group_by/into_groups.rs b/crates/polars-core/src/frame/group_by/into_groups.rs index 47f585006e81..bdaa439a1232 100644 --- a/crates/polars-core/src/frame/group_by/into_groups.rs +++ b/crates/polars-core/src/frame/group_by/into_groups.rs @@ -324,7 +324,7 @@ impl IntoGroupsProxy for ListChunked { let ca = if multithreaded { encode_rows_vertical_par_unordered(by).unwrap() } else { - _get_rows_encoded_ca_unordered("", by).unwrap() + _get_rows_encoded_ca_unordered(PlSmallStr::EMPTY, by).unwrap() }; ca.group_tuples(multithreaded, sorted) diff --git a/crates/polars-core/src/frame/group_by/mod.rs b/crates/polars-core/src/frame/group_by/mod.rs index 5206240a3261..5dd631a51f0f 100644 --- a/crates/polars-core/src/frame/group_by/mod.rs +++ b/crates/polars-core/src/frame/group_by/mod.rs @@ -2,6 +2,7 @@ use std::fmt::{Debug, Display, Formatter}; use std::hash::Hash; use num_traits::NumCast; +use polars_utils::format_pl_smallstr; use polars_utils::hashing::DirtyHash; use rayon::prelude::*; @@ -113,7 +114,7 @@ impl DataFrame { pub fn group_by(&self, by: I) -> PolarsResult where I: IntoIterator, - S: AsRef, + S: Into, { let selected_keys = self.select_series(by)?; self.group_by_with_series(selected_keys, true, false) @@ -124,7 +125,7 @@ impl DataFrame { pub fn group_by_stable(&self, by: I) -> PolarsResult where I: IntoIterator, - S: AsRef, + S: Into, { let selected_keys = self.select_series(by)?; self.group_by_with_series(selected_keys, true, true) @@ -152,9 +153,9 @@ impl DataFrame { /// let s0 = DateChunked::parse_from_str_slice("date", dates, fmt) /// .into_series(); /// // create temperature series -/// let s1 = Series::new("temp", [20, 10, 7, 9, 1]); +/// let s1 = Series::new("temp".into(), [20, 10, 7, 9, 1]); /// // create rain series -/// let s2 = Series::new("rain", [0.2, 0.1, 0.3, 0.1, 0.01]); +/// let s2 = Series::new("rain".into(), [0.2, 0.1, 0.3, 0.1, 0.01]); /// // create a new DataFrame /// let df = DataFrame::new(vec![s0, s1, s2]).unwrap(); /// println!("{:?}", df); @@ -187,7 +188,7 @@ pub struct GroupBy<'df> { // [first idx, [other idx]] groups: GroupsProxy, // columns selected for aggregation - pub(crate) selected_agg: Option>, + pub(crate) selected_agg: Option>, } impl<'df> GroupBy<'df> { @@ -195,7 +196,7 @@ impl<'df> GroupBy<'df> { 
df: &'df DataFrame, by: Vec, groups: GroupsProxy, - selected_agg: Option>, + selected_agg: Option>, ) -> Self { GroupBy { df, @@ -211,13 +212,8 @@ impl<'df> GroupBy<'df> { /// Note that making a selection with this method is not required. If you /// skip it all columns (except for the keys) will be selected for aggregation. #[must_use] - pub fn select, S: AsRef>(mut self, selection: I) -> Self { - self.selected_agg = Some( - selection - .into_iter() - .map(|s| s.as_ref().to_string()) - .collect(), - ); + pub fn select, S: Into>(mut self, selection: I) -> Self { + self.selected_agg = Some(selection.into_iter().map(|s| s.into()).collect()); self } @@ -285,7 +281,10 @@ impl<'df> GroupBy<'df> { ); } - let indices = groups.iter().map(|&[first, _len]| first).collect_ca(""); + let indices = groups + .iter() + .map(|&[first, _len]| first) + .collect_ca(PlSmallStr::EMPTY); // SAFETY: groups are always in bounds. let mut out = unsafe { s.take_unchecked(&indices) }; // Sliced groups are always in order of discovery. @@ -303,21 +302,24 @@ impl<'df> GroupBy<'df> { } fn prepare_agg(&self) -> PolarsResult<(Vec, Vec)> { - let selection = match &self.selected_agg { - Some(selection) => selection.clone(), + let keys = self.keys(); + + let agg_col = match &self.selected_agg { + Some(selection) => self.df.select_series_impl(selection.as_slice()), None => { let by: Vec<_> = self.selected_keys.iter().map(|s| s.name()).collect(); - self.df - .get_column_names() - .into_iter() + let selection = self + .df + .iter() + .map(|s| s.name()) .filter(|a| !by.contains(a)) - .map(|s| s.to_string()) - .collect() + .cloned() + .collect::>(); + + self.df.select_series_impl(selection.as_slice()) }, - }; + }?; - let keys = self.keys(); - let agg_col = self.df.select_series(selection)?; Ok((keys, agg_col)) } @@ -328,7 +330,7 @@ impl<'df> GroupBy<'df> { /// ```rust /// # use polars_core::prelude::*; /// fn example(df: DataFrame) -> PolarsResult { - /// df.group_by(["date"])?.select(&["temp", "rain"]).mean() + /// df.group_by(["date"])?.select(["temp", "rain"]).mean() /// } /// ``` /// Returns: @@ -351,9 +353,9 @@ impl<'df> GroupBy<'df> { let (mut cols, agg_cols) = self.prepare_agg()?; for agg_col in agg_cols { - let new_name = fmt_group_by_column(agg_col.name(), GroupByMethod::Mean); + let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Mean); let mut agg = unsafe { agg_col.agg_mean(&self.groups) }; - agg.rename(&new_name); + agg.rename(new_name); cols.push(agg); } DataFrame::new(cols) @@ -389,9 +391,9 @@ impl<'df> GroupBy<'df> { let (mut cols, agg_cols) = self.prepare_agg()?; for agg_col in agg_cols { - let new_name = fmt_group_by_column(agg_col.name(), GroupByMethod::Sum); + let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Sum); let mut agg = unsafe { agg_col.agg_sum(&self.groups) }; - agg.rename(&new_name); + agg.rename(new_name); cols.push(agg); } DataFrame::new(cols) @@ -426,9 +428,9 @@ impl<'df> GroupBy<'df> { pub fn min(&self) -> PolarsResult { let (mut cols, agg_cols) = self.prepare_agg()?; for agg_col in agg_cols { - let new_name = fmt_group_by_column(agg_col.name(), GroupByMethod::Min); + let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Min); let mut agg = unsafe { agg_col.agg_min(&self.groups) }; - agg.rename(&new_name); + agg.rename(new_name); cols.push(agg); } DataFrame::new(cols) @@ -463,9 +465,9 @@ impl<'df> GroupBy<'df> { pub fn max(&self) -> PolarsResult { let (mut cols, agg_cols) = self.prepare_agg()?; for agg_col in agg_cols { - let 
new_name = fmt_group_by_column(agg_col.name(), GroupByMethod::Max); + let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Max); let mut agg = unsafe { agg_col.agg_max(&self.groups) }; - agg.rename(&new_name); + agg.rename(new_name); cols.push(agg); } DataFrame::new(cols) @@ -500,9 +502,9 @@ impl<'df> GroupBy<'df> { pub fn first(&self) -> PolarsResult { let (mut cols, agg_cols) = self.prepare_agg()?; for agg_col in agg_cols { - let new_name = fmt_group_by_column(agg_col.name(), GroupByMethod::First); + let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::First); let mut agg = unsafe { agg_col.agg_first(&self.groups) }; - agg.rename(&new_name); + agg.rename(new_name); cols.push(agg); } DataFrame::new(cols) @@ -537,9 +539,9 @@ impl<'df> GroupBy<'df> { pub fn last(&self) -> PolarsResult { let (mut cols, agg_cols) = self.prepare_agg()?; for agg_col in agg_cols { - let new_name = fmt_group_by_column(agg_col.name(), GroupByMethod::Last); + let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Last); let mut agg = unsafe { agg_col.agg_last(&self.groups) }; - agg.rename(&new_name); + agg.rename(new_name); cols.push(agg); } DataFrame::new(cols) @@ -574,9 +576,9 @@ impl<'df> GroupBy<'df> { pub fn n_unique(&self) -> PolarsResult { let (mut cols, agg_cols) = self.prepare_agg()?; for agg_col in agg_cols { - let new_name = fmt_group_by_column(agg_col.name(), GroupByMethod::NUnique); + let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::NUnique); let mut agg = unsafe { agg_col.agg_n_unique(&self.groups) }; - agg.rename(&new_name); + agg.rename(new_name); cols.push(agg.into_series()); } DataFrame::new(cols) @@ -606,10 +608,12 @@ impl<'df> GroupBy<'df> { ); let (mut cols, agg_cols) = self.prepare_agg()?; for agg_col in agg_cols { - let new_name = - fmt_group_by_column(agg_col.name(), GroupByMethod::Quantile(quantile, interpol)); + let new_name = fmt_group_by_column( + agg_col.name().as_str(), + GroupByMethod::Quantile(quantile, interpol), + ); let mut agg = unsafe { agg_col.agg_quantile(&self.groups, quantile, interpol) }; - agg.rename(&new_name); + agg.rename(new_name); cols.push(agg.into_series()); } DataFrame::new(cols) @@ -629,9 +633,9 @@ impl<'df> GroupBy<'df> { pub fn median(&self) -> PolarsResult { let (mut cols, agg_cols) = self.prepare_agg()?; for agg_col in agg_cols { - let new_name = fmt_group_by_column(agg_col.name(), GroupByMethod::Median); + let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Median); let mut agg = unsafe { agg_col.agg_median(&self.groups) }; - agg.rename(&new_name); + agg.rename(new_name); cols.push(agg.into_series()); } DataFrame::new(cols) @@ -642,9 +646,9 @@ impl<'df> GroupBy<'df> { pub fn var(&self, ddof: u8) -> PolarsResult { let (mut cols, agg_cols) = self.prepare_agg()?; for agg_col in agg_cols { - let new_name = fmt_group_by_column(agg_col.name(), GroupByMethod::Var(ddof)); + let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Var(ddof)); let mut agg = unsafe { agg_col.agg_var(&self.groups, ddof) }; - agg.rename(&new_name); + agg.rename(new_name); cols.push(agg.into_series()); } DataFrame::new(cols) @@ -655,9 +659,9 @@ impl<'df> GroupBy<'df> { pub fn std(&self, ddof: u8) -> PolarsResult { let (mut cols, agg_cols) = self.prepare_agg()?; for agg_col in agg_cols { - let new_name = fmt_group_by_column(agg_col.name(), GroupByMethod::Std(ddof)); + let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Std(ddof)); let mut 
agg = unsafe { agg_col.agg_std(&self.groups, ddof) }; - agg.rename(&new_name); + agg.rename(new_name); cols.push(agg.into_series()); } DataFrame::new(cols) @@ -693,13 +697,13 @@ impl<'df> GroupBy<'df> { for agg_col in agg_cols { let new_name = fmt_group_by_column( - agg_col.name(), + agg_col.name().as_str(), GroupByMethod::Count { include_nulls: true, }, ); let mut ca = self.groups.group_count(); - ca.rename(&new_name); + ca.rename(new_name); cols.push(ca.into_series()); } DataFrame::new(cols) @@ -734,7 +738,7 @@ impl<'df> GroupBy<'df> { let mut cols = self.keys(); let mut column = self.groups.as_list_chunked(); let new_name = fmt_group_by_column("", GroupByMethod::Groups); - column.rename(&new_name); + column.rename(new_name); cols.push(column.into_series()); DataFrame::new(cols) } @@ -769,9 +773,9 @@ impl<'df> GroupBy<'df> { pub fn agg_list(&self) -> PolarsResult { let (mut cols, agg_cols) = self.prepare_agg()?; for agg_col in agg_cols { - let new_name = fmt_group_by_column(agg_col.name(), GroupByMethod::Implode); + let new_name = fmt_group_by_column(agg_col.name().as_str(), GroupByMethod::Implode); let mut agg = unsafe { agg_col.agg_list(&self.groups) }; - agg.rename(&new_name); + agg.rename(new_name); cols.push(agg); } DataFrame::new(cols) @@ -785,7 +789,7 @@ impl<'df> GroupBy<'df> { } else { let mut new_cols = Vec::with_capacity(self.selected_keys.len() + agg.len()); new_cols.extend_from_slice(&self.selected_keys); - let cols = self.df.select_series(agg)?; + let cols = self.df.select_series_impl(agg.as_slice())?; new_cols.extend(cols); Ok(unsafe { DataFrame::new_no_checks(new_cols) }) } @@ -893,25 +897,25 @@ impl Display for GroupByMethod { } // Formatting functions used in eager and lazy code for renaming grouped columns -pub fn fmt_group_by_column(name: &str, method: GroupByMethod) -> String { +pub fn fmt_group_by_column(name: &str, method: GroupByMethod) -> PlSmallStr { use GroupByMethod::*; match method { - Min => format!("{name}_min"), - Max => format!("{name}_max"), - NanMin => format!("{name}_nan_min"), - NanMax => format!("{name}_nan_max"), - Median => format!("{name}_median"), - Mean => format!("{name}_mean"), - First => format!("{name}_first"), - Last => format!("{name}_last"), - Sum => format!("{name}_sum"), - Groups => "groups".to_string(), - NUnique => format!("{name}_n_unique"), - Count { .. } => format!("{name}_count"), - Implode => format!("{name}_agg_list"), - Quantile(quantile, _interpol) => format!("{name}_quantile_{quantile:.2}"), - Std(_) => format!("{name}_agg_std"), - Var(_) => format!("{name}_agg_var"), + Min => format_pl_smallstr!("{name}_min"), + Max => format_pl_smallstr!("{name}_max"), + NanMin => format_pl_smallstr!("{name}_nan_min"), + NanMax => format_pl_smallstr!("{name}_nan_max"), + Median => format_pl_smallstr!("{name}_median"), + Mean => format_pl_smallstr!("{name}_mean"), + First => format_pl_smallstr!("{name}_first"), + Last => format_pl_smallstr!("{name}_last"), + Sum => format_pl_smallstr!("{name}_sum"), + Groups => PlSmallStr::from_static("groups"), + NUnique => format_pl_smallstr!("{name}_n_unique"), + Count { .. 
} => format_pl_smallstr!("{name}_count"), + Implode => format_pl_smallstr!("{name}_agg_list"), + Quantile(quantile, _interpol) => format_pl_smallstr!("{name}_quantile_{quantile:.2}"), + Std(_) => format_pl_smallstr!("{name}_agg_std"), + Var(_) => format_pl_smallstr!("{name}_agg_var"), } } @@ -926,7 +930,7 @@ mod test { #[cfg_attr(miri, ignore)] fn test_group_by() -> PolarsResult<()> { let s0 = Series::new( - "date", + PlSmallStr::from_static("date"), &[ "2020-08-21", "2020-08-21", @@ -935,14 +939,14 @@ mod test { "2020-08-22", ], ); - let s1 = Series::new("temp", [20, 10, 7, 9, 1]); - let s2 = Series::new("rain", [0.2, 0.1, 0.3, 0.1, 0.01]); + let s1 = Series::new(PlSmallStr::from_static("temp"), [20, 10, 7, 9, 1]); + let s2 = Series::new(PlSmallStr::from_static("rain"), [0.2, 0.1, 0.3, 0.1, 0.01]); let df = DataFrame::new(vec![s0, s1, s2]).unwrap(); let out = df.group_by_stable(["date"])?.select(["temp"]).count()?; assert_eq!( out.column("temp_count")?, - &Series::new("temp_count", [2 as IdxSize, 2, 1]) + &Series::new(PlSmallStr::from_static("temp_count"), [2 as IdxSize, 2, 1]) ); // Use of deprecated mean() for testing purposes @@ -954,7 +958,7 @@ mod test { .mean()?; assert_eq!( out.column("temp_mean")?, - &Series::new("temp_mean", [15.0f64, 4.0, 9.0]) + &Series::new(PlSmallStr::from_static("temp_mean"), [15.0f64, 4.0, 9.0]) ); // Use of deprecated `mean()` for testing purposes @@ -971,7 +975,7 @@ mod test { let out = df.group_by_stable(["date"])?.select(["temp"]).sum()?; assert_eq!( out.column("temp_sum")?, - &Series::new("temp_sum", [30, 8, 9]) + &Series::new(PlSmallStr::from_static("temp_sum"), [30, 8, 9]) ); // Use of deprecated `n_unique()` for testing purposes @@ -987,19 +991,19 @@ mod test { #[cfg_attr(miri, ignore)] fn test_static_group_by_by_12_columns() { // Build GroupBy DataFrame. 
- let s0 = Series::new("G1", ["A", "A", "B", "B", "C"].as_ref()); - let s1 = Series::new("N", [1, 2, 2, 4, 2].as_ref()); - let s2 = Series::new("G2", ["k", "l", "m", "m", "l"].as_ref()); - let s3 = Series::new("G3", ["a", "b", "c", "c", "d"].as_ref()); - let s4 = Series::new("G4", ["1", "2", "3", "3", "4"].as_ref()); - let s5 = Series::new("G5", ["X", "Y", "Z", "Z", "W"].as_ref()); - let s6 = Series::new("G6", [false, true, true, true, false].as_ref()); - let s7 = Series::new("G7", ["r", "x", "q", "q", "o"].as_ref()); - let s8 = Series::new("G8", ["R", "X", "Q", "Q", "O"].as_ref()); - let s9 = Series::new("G9", [1, 2, 3, 3, 4].as_ref()); - let s10 = Series::new("G10", [".", "!", "?", "?", "/"].as_ref()); - let s11 = Series::new("G11", ["(", ")", "@", "@", "$"].as_ref()); - let s12 = Series::new("G12", ["-", "_", ";", ";", ","].as_ref()); + let s0 = Series::new("G1".into(), ["A", "A", "B", "B", "C"].as_ref()); + let s1 = Series::new("N".into(), [1, 2, 2, 4, 2].as_ref()); + let s2 = Series::new("G2".into(), ["k", "l", "m", "m", "l"].as_ref()); + let s3 = Series::new("G3".into(), ["a", "b", "c", "c", "d"].as_ref()); + let s4 = Series::new("G4".into(), ["1", "2", "3", "3", "4"].as_ref()); + let s5 = Series::new("G5".into(), ["X", "Y", "Z", "Z", "W"].as_ref()); + let s6 = Series::new("G6".into(), [false, true, true, true, false].as_ref()); + let s7 = Series::new("G7".into(), ["r", "x", "q", "q", "o"].as_ref()); + let s8 = Series::new("G8".into(), ["R", "X", "Q", "Q", "O"].as_ref()); + let s9 = Series::new("G9".into(), [1, 2, 3, 3, 4].as_ref()); + let s10 = Series::new("G10".into(), [".", "!", "?", "?", "/"].as_ref()); + let s11 = Series::new("G11".into(), ["(", ")", "@", "@", "$"].as_ref()); + let s12 = Series::new("G12".into(), ["-", "_", ";", ";", ","].as_ref()); let df = DataFrame::new(vec![s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]).unwrap(); @@ -1036,13 +1040,13 @@ mod test { let mut series = Vec::with_capacity(14); // Create a series for every group name. - for series_name in &series_names { - let group_series = Series::new(series_name, series_content.as_ref()); + for series_name in series_names { + let group_series = Series::new(series_name.into(), series_content.as_ref()); series.push(group_series); } // Create a series for the aggregation column. - let agg_series = Series::new("N", [1, 2, 3, 3, 4].as_ref()); + let agg_series = Series::new("N".into(), [1, 2, 3, 3, 4].as_ref()); series.push(agg_series); // Create the dataframe with the computed series. 
diff --git a/crates/polars-core/src/frame/group_by/proxy.rs b/crates/polars-core/src/frame/group_by/proxy.rs index 63f37368b756..d9aedc261faf 100644 --- a/crates/polars-core/src/frame/group_by/proxy.rs +++ b/crates/polars-core/src/frame/group_by/proxy.rs @@ -345,7 +345,7 @@ impl GroupsProxy { } unsafe { ( - Some(IdxCa::from_vec("", gather_offsets)), + Some(IdxCa::from_vec(PlSmallStr::EMPTY, gather_offsets)), OffsetsBuffer::new_unchecked(list_offset.into()), can_fast_explode, ) @@ -369,7 +369,7 @@ impl GroupsProxy { unsafe { ( - Some(IdxCa::from_vec("", gather_offsets)), + Some(IdxCa::from_vec(PlSmallStr::EMPTY, gather_offsets)), OffsetsBuffer::new_unchecked(list_offset.into()), can_fast_explode, ) diff --git a/crates/polars-core/src/frame/horizontal.rs b/crates/polars-core/src/frame/horizontal.rs index 6ed4e8bbb356..bcbf486e0877 100644 --- a/crates/polars-core/src/frame/horizontal.rs +++ b/crates/polars-core/src/frame/horizontal.rs @@ -3,11 +3,11 @@ use polars_utils::aliases::PlHashSet; use crate::datatypes::AnyValue; use crate::frame::DataFrame; -use crate::prelude::{Series, SmartString}; +use crate::prelude::{PlSmallStr, Series}; -fn check_hstack<'a>( - col: &'a Series, - names: &mut PlHashSet<&'a str>, +fn check_hstack( + col: &Series, + names: &mut PlHashSet, height: usize, is_empty: bool, ) -> PolarsResult<()> { @@ -17,8 +17,8 @@ fn check_hstack<'a>( col.len(), height, ); polars_ensure!( - names.insert(col.name()), - Duplicate: "unable to hstack, column with name {:?} already exists", col.name(), + names.insert(col.name().clone()), + Duplicate: "unable to hstack, column with name {:?} already exists", col.name().as_str(), ); Ok(()) } @@ -50,7 +50,7 @@ impl DataFrame { let mut names = self .columns .iter() - .map(|c| c.name()) + .map(|c| c.name().clone()) .collect::>(); let height = self.height(); @@ -99,15 +99,12 @@ pub fn concat_df_horizontal(dfs: &[DataFrame], check_duplicates: bool) -> Polars let height = first_df.height(); let is_empty = first_df.is_empty(); - let columns; let mut names = if check_duplicates { - columns = first_df + first_df .columns .iter() - .map(|s| SmartString::from(s.name())) - .collect::>(); - - columns.iter().map(|n| n.as_str()).collect::>() + .map(|s| s.name().clone()) + .collect::>() } else { Default::default() }; diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index 45e110ba7fa1..648141688db8 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -3,7 +3,7 @@ use std::borrow::Cow; use std::{mem, ops}; -use ahash::AHashSet; +use polars_utils::itertools::Itertools; use rayon::prelude::*; #[cfg(feature = "algorithm_group_by")] @@ -27,9 +27,9 @@ mod top_k; mod upstream_traits; use arrow::record_batch::RecordBatch; +use polars_utils::pl_str::PlSmallStr; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -use smartstring::alias::String as SmartString; use crate::chunked_array::cast::CastOptions; #[cfg(feature = "row_hash")] @@ -61,6 +61,37 @@ pub enum UniqueKeepStrategy { Any, } +fn ensure_names_unique(items: &[T], mut get_name: F) -> PolarsResult<()> +where + F: for<'a> FnMut(&'a T) -> &'a str, +{ + // Always unique. + if items.len() <= 1 { + return Ok(()); + } + + if items.len() <= 4 { + // Too small to be worth spawning a hashmap for, this is at most 6 comparisons. 
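+ // (n items need at most n * (n - 1) / 2 pairwise checks, so the n <= 4 case tops out at 4 * 3 / 2 = 6.)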
+ for i in 0..items.len() - 1 { + let name = get_name(&items[i]); + for other in items.iter().skip(i + 1) { + if name == get_name(other) { + polars_bail!(duplicate = name); + } + } + } + } else { + let mut names = PlHashSet::with_capacity(items.len()); + for item in items { + let name = get_name(item); + if !names.insert(name) { + polars_bail!(duplicate = name); + } + } + } + Ok(()) +} + /// A contiguous growable collection of `Series` that have the same length. /// /// ## Use declarations @@ -89,8 +120,8 @@ pub enum UniqueKeepStrategy { /// /// ```rust /// # use polars_core::prelude::*; -/// let s1 = Series::new("Fruit", &["Apple", "Apple", "Pear"]); -/// let s2 = Series::new("Color", &["Red", "Yellow", "Green"]); +/// let s1 = Series::new("Fruit".into(), ["Apple", "Apple", "Pear"]); +/// let s2 = Series::new("Color".into(), ["Red", "Yellow", "Green"]); /// /// let df: PolarsResult = DataFrame::new(vec![s1, s2]); /// ``` @@ -101,8 +132,8 @@ pub enum UniqueKeepStrategy { /// /// ```rust /// # use polars_core::prelude::*; -/// let df: PolarsResult = df!("Fruit" => &["Apple", "Apple", "Pear"], -/// "Color" => &["Red", "Yellow", "Green"]); +/// let df: PolarsResult = df!("Fruit" => ["Apple", "Apple", "Pear"], +/// "Color" => ["Red", "Yellow", "Green"]); /// ``` /// /// ## Using a CSV file @@ -116,11 +147,11 @@ pub enum UniqueKeepStrategy { /// /// ```rust /// # use polars_core::prelude::*; -/// let df = df!("Fruit" => &["Apple", "Apple", "Pear"], -/// "Color" => &["Red", "Yellow", "Green"])?; +/// let df = df!("Fruit" => ["Apple", "Apple", "Pear"], +/// "Color" => ["Red", "Yellow", "Green"])?; /// -/// assert_eq!(df[0], Series::new("Fruit", &["Apple", "Apple", "Pear"])); -/// assert_eq!(df[1], Series::new("Color", &["Red", "Yellow", "Green"])); +/// assert_eq!(df[0], Series::new("Fruit".into(), &["Apple", "Apple", "Pear"])); +/// assert_eq!(df[1], Series::new("Color".into(), &["Red", "Yellow", "Green"])); /// # Ok::<(), PolarsError>(()) /// ``` /// @@ -128,11 +159,11 @@ pub enum UniqueKeepStrategy { /// /// ```rust /// # use polars_core::prelude::*; -/// let df = df!("Fruit" => &["Apple", "Apple", "Pear"], -/// "Color" => &["Red", "Yellow", "Green"])?; +/// let df = df!("Fruit" => ["Apple", "Apple", "Pear"], +/// "Color" => ["Red", "Yellow", "Green"])?; /// -/// assert_eq!(df["Fruit"], Series::new("Fruit", &["Apple", "Apple", "Pear"])); -/// assert_eq!(df["Color"], Series::new("Color", &["Red", "Yellow", "Green"])); +/// assert_eq!(df["Fruit"], Series::new("Fruit".into(), &["Apple", "Apple", "Pear"])); +/// assert_eq!(df["Color"], Series::new("Color".into(), &["Red", "Yellow", "Green"])); /// # Ok::<(), PolarsError>(()) /// ``` #[derive(Clone)] @@ -194,7 +225,7 @@ impl DataFrame { fn check_already_present(&self, name: &str) -> PolarsResult<()> { polars_ensure!( - self.columns.iter().all(|s| s.name() != name), + self.columns.iter().all(|s| s.name().as_str() != name), Duplicate: "column with name {:?} is already present in the DataFrame", name ); Ok(()) @@ -215,95 +246,68 @@ impl DataFrame { /// /// ``` /// # use polars_core::prelude::*; - /// let s0 = Series::new("days", [0, 1, 2].as_ref()); - /// let s1 = Series::new("temp", [22.1, 19.9, 7.].as_ref()); + /// let s0 = Series::new("days".into(), [0, 1, 2].as_ref()); + /// let s1 = Series::new("temp".into(), [22.1, 19.9, 7.].as_ref()); /// /// let df = DataFrame::new(vec![s0, s1])?; /// # Ok::<(), PolarsError>(()) /// ``` - pub fn new(columns: Vec) -> PolarsResult { - let mut first_len = None; + pub fn new(columns: Vec) -> PolarsResult { + 
ensure_names_unique(&columns, |s| s.name().as_str())?; - let shape_err = |&first_name, &first_len, &name, &len| { - polars_bail!( - ShapeMismatch: "could not create a new DataFrame: series {:?} has length {} \ - while series {:?} has length {}", - first_name, first_len, name, len - ); - }; + if columns.len() > 1 { + let first_len = columns[0].len(); + for col in &columns { + polars_ensure!( + col.len() == first_len, + ShapeMismatch: "could not create a new DataFrame: series {:?} has length {} while series {:?} has length {}", + columns[0].len(), first_len, col.name(), col.len() + ); + } + } - let series_cols = if S::is_series() { - // SAFETY: - // we are guarded by the type system here. - #[allow(clippy::transmute_undefined_repr)] - let series_cols = unsafe { std::mem::transmute::, Vec>(columns) }; - let mut names = PlHashSet::with_capacity(series_cols.len()); - - for s in &series_cols { - let name = s.name(); - - match first_len { - Some(len) => { - if s.len() != len { - let first_series = &series_cols.first().unwrap(); - return shape_err( - &first_series.name(), - &first_series.len(), - &name, - &s.len(), - ); - } - }, - None => first_len = Some(s.len()), - } + Ok(DataFrame { columns }) + } - if !names.insert(name) { - polars_bail!(duplicate = name); - } - } - // we drop early as the brchk thinks the &str borrows are used when calling the drop - // of both `series_cols` and `names` - drop(names); - series_cols - } else { - let mut series_cols: Vec = Vec::with_capacity(columns.len()); - let mut names = PlHashSet::with_capacity(columns.len()); - - // check for series length equality and convert into series in one pass - for s in columns { - let series = s.into_series(); - // we have aliasing borrows so we must allocate a string - let name = series.name().to_string(); - - match first_len { - Some(len) => { - if series.len() != len { - let first_series = &series_cols.first().unwrap(); - return shape_err( - &first_series.name(), - &first_series.len(), - &name.as_str(), - &series.len(), - ); - } - }, - None => first_len = Some(series.len()), - } + /// Converts a sequence of columns into a DataFrame, broadcasting length-1 + /// columns to match the other columns. + pub fn new_with_broadcast(columns: Vec) -> PolarsResult { + ensure_names_unique(&columns, |s| s.name().as_str())?; + unsafe { Self::new_with_broadcast_no_checks(columns) } + } - if names.contains(&name) { - polars_bail!(duplicate = name); + /// Converts a sequence of columns into a DataFrame, broadcasting length-1 + /// columns to match the other columns. + /// + /// # Safety + /// Does not check that the column names are unique (which they must be). + pub unsafe fn new_with_broadcast_no_checks(mut columns: Vec) -> PolarsResult { + // The length of the longest non-unit length column determines the + // broadcast length. If all columns are unit-length the broadcast length + // is one. + let broadcast_len = columns + .iter() + .map(|s| s.len()) + .filter(|l| *l != 1) + .max() + .unwrap_or(1); + + for col in &mut columns { + // Length not equal to the broadcast len, needs broadcast or is an error. 
+ let len = col.len(); + if len != broadcast_len { + if len != 1 { + let name = col.name().to_owned(); + let longest_column = columns.iter().max_by_key(|c| c.len()).unwrap().name(); + polars_bail!( + ShapeMismatch: "could not create a new DataFrame: series {:?} has length {} while series {:?} has length {}", + name, len, longest_column, broadcast_len + ); } - - series_cols.push(series); - names.insert(name); + *col = col.new_from_index(0, broadcast_len); } - drop(names); - series_cols - }; - - Ok(DataFrame { - columns: series_cols, - }) + } + Ok(unsafe { DataFrame::new_no_checks(columns) }) } /// Creates an empty `DataFrame` usable in a compile time context (such as static initializers). @@ -323,7 +327,7 @@ impl DataFrame { pub fn empty_with_schema(schema: &Schema) -> Self { let cols = schema .iter() - .map(|(name, dtype)| Series::new_empty(name, dtype)) + .map(|(name, dtype)| Series::new_empty(name.clone(), dtype)) .collect(); unsafe { DataFrame::new_no_checks(cols) } } @@ -331,9 +335,8 @@ impl DataFrame { /// Create an empty `DataFrame` with empty columns as per the `schema`. pub fn empty_with_arrow_schema(schema: &ArrowSchema) -> Self { let cols = schema - .fields - .iter() - .map(|fld| Series::new_empty(fld.name.as_str(), &(fld.data_type().into()))) + .iter_values() + .map(|fld| Series::new_empty(fld.name.clone(), &(fld.dtype().into()))) .collect(); unsafe { DataFrame::new_no_checks(cols) } } @@ -344,8 +347,8 @@ impl DataFrame { /// /// ```rust /// # use polars_core::prelude::*; - /// let s1 = Series::new("Ocean", &["Atlantic", "Indian"]); - /// let s2 = Series::new("Area (km²)", &[106_460_000, 70_560_000]); + /// let s1 = Series::new("Ocean".into(), ["Atlantic", "Indian"]); + /// let s2 = Series::new("Area (km²)".into(), [106_460_000, 70_560_000]); /// let mut df = DataFrame::new(vec![s1.clone(), s2.clone()])?; /// /// assert_eq!(df.pop(), Some(s2)); @@ -364,10 +367,10 @@ impl DataFrame { /// /// ``` /// # use polars_core::prelude::*; - /// let df1: DataFrame = df!("Name" => &["James", "Mary", "John", "Patricia"])?; + /// let df1: DataFrame = df!("Name" => ["James", "Mary", "John", "Patricia"])?; /// assert_eq!(df1.shape(), (4, 1)); /// - /// let df2: DataFrame = df1.with_row_index("Id", None)?; + /// let df2: DataFrame = df1.with_row_index("Id".into(), None)?; /// assert_eq!(df2.shape(), (4, 2)); /// println!("{}", df2); /// @@ -392,7 +395,7 @@ impl DataFrame { /// | 3 | Patricia | /// +-----+----------+ /// ``` - pub fn with_row_index(&self, name: &str, offset: Option) -> PolarsResult { + pub fn with_row_index(&self, name: PlSmallStr, offset: Option) -> PolarsResult { let mut columns = Vec::with_capacity(self.columns.len() + 1); let offset = offset.unwrap_or(0); @@ -408,7 +411,7 @@ impl DataFrame { } /// Add a row index column in place. - pub fn with_row_index_mut(&mut self, name: &str, offset: Option) -> &mut Self { + pub fn with_row_index_mut(&mut self, name: PlSmallStr, offset: Option) -> &mut Self { let offset = offset.unwrap_or(0); let mut ca = IdxCa::from_vec( name, @@ -442,16 +445,7 @@ impl DataFrame { /// It is the callers responsibility to uphold the contract of all `Series` /// having an equal length, if not this may panic down the line. 
pub unsafe fn new_no_length_checks(columns: Vec) -> PolarsResult { - let mut names = PlHashSet::with_capacity(columns.len()); - for column in &columns { - let name = column.name(); - if !names.insert(name) { - polars_bail!(duplicate = name); - } - } - // we drop early as the brchk thinks the &str borrows are used when calling the drop - // of both `columns` and `names` - drop(names); + ensure_names_unique(&columns, |s| s.name().as_str())?; Ok(DataFrame { columns }) } @@ -533,18 +527,21 @@ impl DataFrame { /// /// ```rust /// # use polars_core::prelude::*; - /// let df: DataFrame = df!("Thing" => &["Observable universe", "Human stupidity"], - /// "Diameter (m)" => &[8.8e26, f64::INFINITY])?; + /// let df: DataFrame = df!("Thing" => ["Observable universe", "Human stupidity"], + /// "Diameter (m)" => [8.8e26, f64::INFINITY])?; /// - /// let f1: Field = Field::new("Thing", DataType::String); - /// let f2: Field = Field::new("Diameter (m)", DataType::Float64); + /// let f1: Field = Field::new("Thing".into(), DataType::String); + /// let f2: Field = Field::new("Diameter (m)".into(), DataType::Float64); /// let sc: Schema = Schema::from_iter(vec![f1, f2]); /// /// assert_eq!(df.schema(), sc); /// # Ok::<(), PolarsError>(()) /// ``` pub fn schema(&self) -> Schema { - self.columns.as_slice().into() + self.columns + .iter() + .map(|x| (x.name().clone(), x.dtype().clone())) + .collect() } /// Get a reference to the [`DataFrame`] columns. @@ -553,8 +550,8 @@ impl DataFrame { /// /// ```rust /// # use polars_core::prelude::*; - /// let df: DataFrame = df!("Name" => &["Adenine", "Cytosine", "Guanine", "Thymine"], - /// "Symbol" => &["A", "C", "G", "T"])?; + /// let df: DataFrame = df!("Name" => ["Adenine", "Cytosine", "Guanine", "Thymine"], + /// "Symbol" => ["A", "C", "G", "T"])?; /// let columns: &[Series] = df.get_columns(); /// /// assert_eq!(columns[0].name(), "Name"); @@ -575,14 +572,19 @@ impl DataFrame { &mut self.columns } + /// Take ownership of the underlying columns vec. + pub fn take_columns(self) -> Vec { + self.columns + } + /// Iterator over the columns as [`Series`]. /// /// # Example /// /// ```rust /// # use polars_core::prelude::*; - /// let s1: Series = Series::new("Name", &["Pythagoras' theorem", "Shannon entropy"]); - /// let s2: Series = Series::new("Formula", &["a²+b²=c²", "H=-Σ[P(x)log|P(x)|]"]); + /// let s1: Series = Series::new("Name".into(), ["Pythagoras' theorem", "Shannon entropy"]); + /// let s2: Series = Series::new("Formula".into(), ["a²+b²=c²", "H=-Σ[P(x)log|P(x)|]"]); /// let df: DataFrame = DataFrame::new(vec![s1.clone(), s2.clone()])?; /// /// let mut iterator = df.iter(); @@ -600,19 +602,23 @@ impl DataFrame { /// /// ```rust /// # use polars_core::prelude::*; - /// let df: DataFrame = df!("Language" => &["Rust", "Python"], - /// "Designer" => &["Graydon Hoare", "Guido van Rossum"])?; + /// let df: DataFrame = df!("Language" => ["Rust", "Python"], + /// "Designer" => ["Graydon Hoare", "Guido van Rossum"])?; /// /// assert_eq!(df.get_column_names(), &["Language", "Designer"]); /// # Ok::<(), PolarsError>(()) /// ``` - pub fn get_column_names(&self) -> Vec<&str> { + pub fn get_column_names(&self) -> Vec<&PlSmallStr> { self.columns.iter().map(|s| s.name()).collect() } - /// Get the [`Vec`] representing the column names. - pub fn get_column_names_owned(&self) -> Vec { - self.columns.iter().map(|s| s.name().into()).collect() + /// Get the [`Vec`] representing the column names. 
+ pub fn get_column_names_owned(&self) -> Vec { + self.columns.iter().map(|s| s.name().clone()).collect() + } + + pub fn get_column_names_str(&self) -> Vec<&str> { + self.columns.iter().map(|s| s.name().as_str()).collect() } /// Set the column names. @@ -620,24 +626,28 @@ impl DataFrame { /// /// ```rust /// # use polars_core::prelude::*; - /// let mut df: DataFrame = df!("Mathematical set" => &["ℕ", "ℤ", "𝔻", "ℚ", "ℝ", "ℂ"])?; - /// df.set_column_names(&["Set"])?; + /// let mut df: DataFrame = df!("Mathematical set" => ["ℕ", "ℤ", "𝔻", "ℚ", "ℝ", "ℂ"])?; + /// df.set_column_names(["Set"])?; /// /// assert_eq!(df.get_column_names(), &["Set"]); /// # Ok::<(), PolarsError>(()) /// ``` - pub fn set_column_names>(&mut self, names: &[S]) -> PolarsResult<()> { + pub fn set_column_names(&mut self, names: I) -> PolarsResult<()> + where + I: IntoIterator, + S: Into, + { + let names = names.into_iter().map(Into::into).collect::>(); + self._set_column_names_impl(names.as_slice()) + } + + fn _set_column_names_impl(&mut self, names: &[PlSmallStr]) -> PolarsResult<()> { polars_ensure!( names.len() == self.width(), ShapeMismatch: "{} column names provided for a DataFrame of width {}", names.len(), self.width() ); - let unique_names: AHashSet<&str, ahash::RandomState> = - AHashSet::from_iter(names.iter().map(|name| name.as_ref())); - polars_ensure!( - unique_names.len() == self.width(), - Duplicate: "duplicate column names found" - ); + ensure_names_unique(names, |s| s.as_str())?; let columns = mem::take(&mut self.columns); self.columns = columns @@ -645,7 +655,7 @@ impl DataFrame { .zip(names) .map(|(s, name)| { let mut s = s; - s.rename(name.as_ref()); + s.rename(name.clone()); s }) .collect(); @@ -658,8 +668,8 @@ impl DataFrame { /// /// ```rust /// # use polars_core::prelude::*; - /// let venus_air: DataFrame = df!("Element" => &["Carbon dioxide", "Nitrogen"], - /// "Fraction" => &[0.965, 0.035])?; + /// let venus_air: DataFrame = df!("Element" => ["Carbon dioxide", "Nitrogen"], + /// "Fraction" => [0.965, 0.035])?; /// /// assert_eq!(venus_air.dtypes(), &[DataType::String, DataType::Float64]); /// # Ok::<(), PolarsError>(()) @@ -682,11 +692,11 @@ impl DataFrame { /// /// ```rust /// # use polars_core::prelude::*; - /// let earth: DataFrame = df!("Surface type" => &["Water", "Land"], - /// "Fraction" => &[0.708, 0.292])?; + /// let earth: DataFrame = df!("Surface type" => ["Water", "Land"], + /// "Fraction" => [0.708, 0.292])?; /// - /// let f1: Field = Field::new("Surface type", DataType::String); - /// let f2: Field = Field::new("Fraction", DataType::Float64); + /// let f1: Field = Field::new("Surface type".into(), DataType::String); + /// let f2: Field = Field::new("Fraction".into(), DataType::Float64); /// /// assert_eq!(earth.fields(), &[f1, f2]); /// # Ok::<(), PolarsError>(()) @@ -705,9 +715,9 @@ impl DataFrame { /// ```rust /// # use polars_core::prelude::*; /// let df0: DataFrame = DataFrame::default(); - /// let df1: DataFrame = df!("1" => &[1, 2, 3, 4, 5])?; - /// let df2: DataFrame = df!("1" => &[1, 2, 3, 4, 5], - /// "2" => &[1, 2, 3, 4, 5])?; + /// let df1: DataFrame = df!("1" => [1, 2, 3, 4, 5])?; + /// let df2: DataFrame = df!("1" => [1, 2, 3, 4, 5], + /// "2" => [1, 2, 3, 4, 5])?; /// /// assert_eq!(df0.shape(), (0 ,0)); /// assert_eq!(df1.shape(), (5, 1)); @@ -728,9 +738,9 @@ impl DataFrame { /// ```rust /// # use polars_core::prelude::*; /// let df0: DataFrame = DataFrame::default(); - /// let df1: DataFrame = df!("Series 1" => &[0; 0])?; - /// let df2: DataFrame = df!("Series 1" => 
&[0; 0], - /// "Series 2" => &[0; 0])?; + /// let df1: DataFrame = df!("Series 1" => [0; 0])?; + /// let df2: DataFrame = df!("Series 1" => [0; 0], + /// "Series 2" => [0; 0])?; /// /// assert_eq!(df0.width(), 0); /// assert_eq!(df1.width(), 1); @@ -748,8 +758,8 @@ impl DataFrame { /// ```rust /// # use polars_core::prelude::*; /// let df0: DataFrame = DataFrame::default(); - /// let df1: DataFrame = df!("Currency" => &["€", "$"])?; - /// let df2: DataFrame = df!("Currency" => &["€", "$", "¥", "£", "₿"])?; + /// let df1: DataFrame = df!("Currency" => ["€", "$"])?; + /// let df2: DataFrame = df!("Currency" => ["€", "$", "¥", "£", "₿"])?; /// /// assert_eq!(df0.height(), 0); /// assert_eq!(df1.height(), 2); @@ -769,8 +779,8 @@ impl DataFrame { /// let df1: DataFrame = DataFrame::default(); /// assert!(df1.is_empty()); /// - /// let df2: DataFrame = df!("First name" => &["Forever"], - /// "Last name" => &["Alone"])?; + /// let df2: DataFrame = df!("First name" => ["Forever"], + /// "Last name" => ["Alone"])?; /// assert!(!df2.is_empty()); /// # Ok::<(), PolarsError>(()) /// ``` @@ -785,9 +795,9 @@ impl DataFrame { /// /// ```rust /// # use polars_core::prelude::*; - /// let df1: DataFrame = df!("Element" => &["Copper", "Silver", "Gold"])?; - /// let s1: Series = Series::new("Proton", &[29, 47, 79]); - /// let s2: Series = Series::new("Electron", &[29, 47, 79]); + /// let df1: DataFrame = df!("Element" => ["Copper", "Silver", "Gold"])?; + /// let s1: Series = Series::new("Proton".into(), [29, 47, 79]); + /// let s2: Series = Series::new("Electron".into(), [29, 47, 79]); /// /// let df2: DataFrame = df1.hstack(&[s1, s2])?; /// assert_eq!(df2.shape(), (3, 3)); @@ -825,10 +835,10 @@ impl DataFrame { /// /// ```rust /// # use polars_core::prelude::*; - /// let df1: DataFrame = df!("Element" => &["Copper", "Silver", "Gold"], - /// "Melting Point (K)" => &[1357.77, 1234.93, 1337.33])?; - /// let df2: DataFrame = df!("Element" => &["Platinum", "Palladium"], - /// "Melting Point (K)" => &[2041.4, 1828.05])?; + /// let df1: DataFrame = df!("Element" => ["Copper", "Silver", "Gold"], + /// "Melting Point (K)" => [1357.77, 1234.93, 1337.33])?; + /// let df2: DataFrame = df!("Element" => ["Platinum", "Palladium"], + /// "Melting Point (K)" => [2041.4, 1828.05])?; /// /// let df3: DataFrame = df1.vstack(&df2)?; /// @@ -871,10 +881,10 @@ impl DataFrame { /// /// ```rust /// # use polars_core::prelude::*; - /// let mut df1: DataFrame = df!("Element" => &["Copper", "Silver", "Gold"], - /// "Melting Point (K)" => &[1357.77, 1234.93, 1337.33])?; - /// let df2: DataFrame = df!("Element" => &["Platinum", "Palladium"], - /// "Melting Point (K)" => &[2041.4, 1828.05])?; + /// let mut df1: DataFrame = df!("Element" => ["Copper", "Silver", "Gold"], + /// "Melting Point (K)" => [1357.77, 1234.93, 1337.33])?; + /// let df2: DataFrame = df!("Element" => ["Platinum", "Palladium"], + /// "Melting Point (K)" => [2041.4, 1828.05])?; /// /// df1.vstack_mut(&df2)?; /// @@ -926,8 +936,13 @@ impl DataFrame { Ok(self) } - /// Does not check if schema is correct - pub(crate) fn vstack_mut_unchecked(&mut self, other: &DataFrame) { + /// Concatenate a [`DataFrame`] to this [`DataFrame`] + /// + /// If many `vstack` operations are done, it is recommended to call [`DataFrame::align_chunks`]. + /// + /// # Panics + /// Panics if the schema's don't match. 
+ pub fn vstack_mut_unchecked(&mut self, other: &DataFrame) { self.columns .iter_mut() .zip(other.columns.iter()) @@ -973,14 +988,14 @@ impl DataFrame { /// /// ```rust /// # use polars_core::prelude::*; - /// let mut df: DataFrame = df!("Animal" => &["Tiger", "Lion", "Great auk"], - /// "IUCN" => &["Endangered", "Vulnerable", "Extinct"])?; + /// let mut df: DataFrame = df!("Animal" => ["Tiger", "Lion", "Great auk"], + /// "IUCN" => ["Endangered", "Vulnerable", "Extinct"])?; /// /// let s1: PolarsResult = df.drop_in_place("Average weight"); /// assert!(s1.is_err()); /// /// let s2: Series = df.drop_in_place("Animal")?; - /// assert_eq!(s2, Series::new("Animal", &["Tiger", "Lion", "Great auk"])); + /// assert_eq!(s2, Series::new("Animal".into(), &["Tiger", "Lion", "Great auk"])); /// # Ok::<(), PolarsError>(()) /// ``` pub fn drop_in_place(&mut self, name: &str) -> PolarsResult { @@ -1016,22 +1031,26 @@ impl DataFrame { /// | Malta | 32.7 | /// +---------+---------------------+ /// ``` - pub fn drop_nulls>(&self, subset: Option<&[S]>) -> PolarsResult { - let selected_series; - - let mut iter = match subset { - Some(cols) => { - selected_series = self.select_series(cols)?; - selected_series.iter() - }, - None => self.columns.iter(), - }; + pub fn drop_nulls(&self, subset: Option<&[S]>) -> PolarsResult + where + for<'a> &'a S: Into, + { + if let Some(v) = subset { + let v = self.select_series(v)?; + self._drop_nulls_impl(v.as_slice()) + } else { + self._drop_nulls_impl(self.columns.as_slice()) + } + } + fn _drop_nulls_impl(&self, subset: &[Series]) -> PolarsResult { // fast path for no nulls in df - if iter.clone().all(|s| !s.has_nulls()) { + if subset.iter().all(|s| !s.has_nulls()) { return Ok(self.clone()); } + let mut iter = subset.iter(); + let mask = iter .next() .ok_or_else(|| polars_err!(NoData: "no data to drop nulls from"))?; @@ -1051,7 +1070,7 @@ impl DataFrame { /// /// ```rust /// # use polars_core::prelude::*; - /// let df1: DataFrame = df!("Ray type" => &["α", "β", "X", "γ"])?; + /// let df1: DataFrame = df!("Ray type" => ["α", "β", "X", "γ"])?; /// let df2: DataFrame = df1.drop("Ray type")?; /// /// assert!(df2.is_empty()); @@ -1071,19 +1090,23 @@ impl DataFrame { } /// Drop columns that are in `names`. - pub fn drop_many>(&self, names: &[S]) -> Self { - let names: PlHashSet<_> = names.iter().map(|s| s.as_ref()).collect(); + pub fn drop_many(&self, names: I) -> Self + where + I: IntoIterator, + S: Into, + { + let names: PlHashSet = names.into_iter().map(|s| s.into()).collect(); self.drop_many_amortized(&names) } /// Drop columns that are in `names` without allocating a [`HashSet`](std::collections::HashSet). 
- pub fn drop_many_amortized(&self, names: &PlHashSet<&str>) -> DataFrame { + pub fn drop_many_amortized(&self, names: &PlHashSet) -> DataFrame { if names.is_empty() { return self.clone(); } let mut new_cols = Vec::with_capacity(self.columns.len().saturating_sub(names.len())); self.columns.iter().for_each(|s| { - if !names.contains(&s.name()) { + if !names.contains(s.name()) { new_cols.push(s.clone()) } }); @@ -1114,12 +1137,12 @@ impl DataFrame { column: S, ) -> PolarsResult<&mut Self> { let series = column.into_series(); - self.check_already_present(series.name())?; + self.check_already_present(series.name().as_str())?; self.insert_column_no_name_check(index, series) } fn add_column_by_search(&mut self, series: Series) -> PolarsResult<()> { - if let Some(idx) = self.get_column_index(series.name()) { + if let Some(idx) = self.get_column_index(series.name().as_str()) { self.replace_column(idx, series)?; } else { self.columns.push(series); @@ -1161,13 +1184,20 @@ impl DataFrame { /// # Safety /// The caller must ensure `column.len() == self.height()` . pub unsafe fn with_column_unchecked(&mut self, column: Series) -> &mut Self { - self.get_columns_mut().push(column); - self + #[cfg(debug_assertions)] + { + return self.with_column(column).unwrap(); + } + #[cfg(not(debug_assertions))] + { + self.get_columns_mut().push(column); + self + } } fn add_column_by_schema(&mut self, s: Series, schema: &Schema) -> PolarsResult<()> { let name = s.name(); - if let Some((idx, _, _)) = schema.get_full(name) { + if let Some((idx, _, _)) = schema.get_full(name.as_str()) { // schema is incorrect fallback to search if self.columns.get(idx).map(|s| s.name()) != Some(name) { self.add_column_by_search(s)?; @@ -1184,7 +1214,7 @@ impl DataFrame { for (i, s) in columns.into_iter().enumerate() { // we need to branch here // because users can add multiple columns with the same name - if i == 0 || schema.get(s.name()).is_some() { + if i == 0 || schema.get(s.name().as_str()).is_some() { self.with_column_and_schema(s, schema)?; } else { self.with_column(s.clone())?; @@ -1254,11 +1284,11 @@ impl DataFrame { /// /// ```rust /// # use polars_core::prelude::*; - /// let df: DataFrame = df!("Star" => &["Sun", "Betelgeuse", "Sirius A", "Sirius B"], - /// "Absolute magnitude" => &[4.83, -5.85, 1.42, 11.18])?; + /// let df: DataFrame = df!("Star" => ["Sun", "Betelgeuse", "Sirius A", "Sirius B"], + /// "Absolute magnitude" => [4.83, -5.85, 1.42, 11.18])?; /// /// let s1: Option<&Series> = df.select_at_idx(0); - /// let s2: Series = Series::new("Star", &["Sun", "Betelgeuse", "Sirius A", "Sirius B"]); + /// let s2: Series = Series::new("Star".into(), ["Sun", "Betelgeuse", "Sirius A", "Sirius B"]); /// /// assert_eq!(s1, Some(&s2)); /// # Ok::<(), PolarsError>(()) @@ -1282,12 +1312,12 @@ impl DataFrame { /// ```rust /// # use polars_core::prelude::*; /// let df = df! 
{ - /// "0" => &[0, 0, 0], - /// "1" => &[1, 1, 1], - /// "2" => &[2, 2, 2] + /// "0" => [0, 0, 0], + /// "1" => [1, 1, 1], + /// "2" => [2, 2, 2] /// }?; /// - /// assert!(df.select(&["0", "1"])?.equals(&df.select_by_range(0..=1)?)); + /// assert!(df.select(["0", "1"])?.equals(&df.select_by_range(0..=1)?)); /// assert!(df.equals(&df.select_by_range(..)?)); /// # Ok::<(), PolarsError>(()) /// ``` @@ -1342,10 +1372,10 @@ impl DataFrame { /// /// ```rust /// # use polars_core::prelude::*; - /// let df: DataFrame = df!("Name" => &["Player 1", "Player 2", "Player 3"], - /// "Health" => &[100, 200, 500], - /// "Mana" => &[250, 100, 0], - /// "Strength" => &[30, 150, 300])?; + /// let df: DataFrame = df!("Name" => ["Player 1", "Player 2", "Player 3"], + /// "Health" => [100, 200, 500], + /// "Mana" => [250, 100, 0], + /// "Strength" => [30, 150, 300])?; /// /// assert_eq!(df.get_column_index("Name"), Some(0)); /// assert_eq!(df.get_column_index("Health"), Some(1)); @@ -1355,7 +1385,7 @@ impl DataFrame { /// # Ok::<(), PolarsError>(()) /// ``` pub fn get_column_index(&self, name: &str) -> Option { - self.columns.iter().position(|s| s.name() == name) + self.columns.iter().position(|s| s.name().as_str() == name) } /// Get column index of a [`Series`] by name. @@ -1370,8 +1400,8 @@ impl DataFrame { /// /// ```rust /// # use polars_core::prelude::*; - /// let s1: Series = Series::new("Password", &["123456", "[]B$u$g$s$B#u#n#n#y[]{}"]); - /// let s2: Series = Series::new("Robustness", &["Weak", "Strong"]); + /// let s1: Series = Series::new("Password".into(), ["123456", "[]B$u$g$s$B#u#n#n#y[]{}"]); + /// let s2: Series = Series::new("Robustness".into(), ["Weak", "Strong"]); /// let df: DataFrame = DataFrame::new(vec![s1.clone(), s2])?; /// /// assert_eq!(df.column("Password")?, &s1); @@ -1388,9 +1418,9 @@ impl DataFrame { /// /// ```rust /// # use polars_core::prelude::*; - /// let df: DataFrame = df!("Latin name" => &["Oncorhynchus kisutch", "Salmo salar"], - /// "Max weight (kg)" => &[16.0, 35.89])?; - /// let sv: Vec<&Series> = df.columns(&["Latin name", "Max weight (kg)"])?; + /// let df: DataFrame = df!("Latin name" => ["Oncorhynchus kisutch", "Salmo salar"], + /// "Max weight (kg)" => [16.0, 35.89])?; + /// let sv: Vec<&Series> = df.columns(["Latin name", "Max weight (kg)"])?; /// /// assert_eq!(&df[0], sv[0]); /// assert_eq!(&df[1], sv[1]); @@ -1420,21 +1450,18 @@ impl DataFrame { pub fn select(&self, selection: I) -> PolarsResult where I: IntoIterator, - S: AsRef, + S: Into, { - let cols = selection - .into_iter() - .map(|s| SmartString::from(s.as_ref())) - .collect::>(); - self._select_impl(&cols) + let cols = selection.into_iter().map(|s| s.into()).collect::>(); + self._select_impl(cols.as_slice()) } - pub fn _select_impl(&self, cols: &[SmartString]) -> PolarsResult { - self.select_check_duplicates(cols)?; + pub fn _select_impl(&self, cols: &[PlSmallStr]) -> PolarsResult { + ensure_names_unique(cols, |s| s.as_str())?; self._select_impl_unchecked(cols) } - pub fn _select_impl_unchecked(&self, cols: &[SmartString]) -> PolarsResult { + pub fn _select_impl_unchecked(&self, cols: &[PlSmallStr]) -> PolarsResult { let selected = self.select_series_impl(cols)?; Ok(unsafe { DataFrame::new_no_checks(selected) }) } @@ -1443,13 +1470,10 @@ impl DataFrame { pub fn select_with_schema(&self, selection: I, schema: &SchemaRef) -> PolarsResult where I: IntoIterator, - S: AsRef, + S: Into, { - let cols = selection - .into_iter() - .map(|s| SmartString::from(s.as_ref())) - .collect::>(); - 
self.select_with_schema_impl(&cols, schema, true) + let cols = selection.into_iter().map(|s| s.into()).collect::>(); + self._select_with_schema_impl(&cols, schema, true) } /// Select with a known schema. This doesn't check for duplicates. @@ -1460,23 +1484,20 @@ impl DataFrame { ) -> PolarsResult where I: IntoIterator, - S: AsRef, + S: Into, { - let cols = selection - .into_iter() - .map(|s| SmartString::from(s.as_ref())) - .collect::>(); - self.select_with_schema_impl(&cols, schema, false) + let cols = selection.into_iter().map(|s| s.into()).collect::>(); + self._select_with_schema_impl(&cols, schema, false) } - fn select_with_schema_impl( + pub fn _select_with_schema_impl( &self, - cols: &[SmartString], + cols: &[PlSmallStr], schema: &Schema, check_duplicates: bool, ) -> PolarsResult { if check_duplicates { - self.select_check_duplicates(cols)?; + ensure_names_unique(cols, |s| s.as_str())?; } let selected = self.select_series_impl_with_schema(cols, schema)?; Ok(unsafe { DataFrame::new_no_checks(selected) }) @@ -1485,12 +1506,12 @@ impl DataFrame { /// A non generic implementation to reduce compiler bloat. fn select_series_impl_with_schema( &self, - cols: &[SmartString], + cols: &[PlSmallStr], schema: &Schema, ) -> PolarsResult> { cols.iter() .map(|name| { - let index = schema.try_get_full(name)?.0; + let index = schema.try_get_full(name.as_str())?.0; Ok(self.columns[index].clone()) }) .collect() @@ -1499,47 +1520,34 @@ impl DataFrame { pub fn select_physical(&self, selection: I) -> PolarsResult where I: IntoIterator, - S: AsRef, + S: Into, { - let cols = selection - .into_iter() - .map(|s| SmartString::from(s.as_ref())) - .collect::>(); + let cols = selection.into_iter().map(|s| s.into()).collect::>(); self.select_physical_impl(&cols) } - fn select_physical_impl(&self, cols: &[SmartString]) -> PolarsResult { - self.select_check_duplicates(cols)?; + fn select_physical_impl(&self, cols: &[PlSmallStr]) -> PolarsResult { + ensure_names_unique(cols, |s| s.as_str())?; let selected = self.select_series_physical_impl(cols)?; Ok(unsafe { DataFrame::new_no_checks(selected) }) } - fn select_check_duplicates(&self, cols: &[SmartString]) -> PolarsResult<()> { - let mut names = PlHashSet::with_capacity(cols.len()); - for name in cols { - if !names.insert(name.as_str()) { - polars_bail!(duplicate = name); - } - } - Ok(()) - } - /// Select column(s) from this [`DataFrame`] and return them into a [`Vec`]. /// /// # Example /// /// ```rust /// # use polars_core::prelude::*; - /// let df: DataFrame = df!("Name" => &["Methane", "Ethane", "Propane"], - /// "Carbon" => &[1, 2, 3], - /// "Hydrogen" => &[4, 6, 8])?; - /// let sv: Vec = df.select_series(&["Carbon", "Hydrogen"])?; + /// let df: DataFrame = df!("Name" => ["Methane", "Ethane", "Propane"], + /// "Carbon" => [1, 2, 3], + /// "Hydrogen" => [4, 6, 8])?; + /// let sv: Vec = df.select_series(["Carbon", "Hydrogen"])?; /// /// assert_eq!(df["Carbon"], sv[0]); /// assert_eq!(df["Hydrogen"], sv[1]); /// # Ok::<(), PolarsError>(()) /// ``` - pub fn select_series(&self, selection: impl IntoVec) -> PolarsResult> { + pub fn select_series(&self, selection: impl IntoVec) -> PolarsResult> { let cols = selection.into_vec(); self.select_series_impl(&cols) } @@ -1548,12 +1556,12 @@ impl DataFrame { self.columns .iter() .enumerate() - .map(|(i, s)| (s.name(), i)) + .map(|(i, s)| (s.name().as_str(), i)) .collect() } /// A non generic implementation to reduce compiler bloat. 
- fn select_series_physical_impl(&self, cols: &[SmartString]) -> PolarsResult> { + fn select_series_physical_impl(&self, cols: &[PlSmallStr]) -> PolarsResult> { let selected = if cols.len() > 1 && self.columns.len() > 10 { let name_to_idx = self._names_to_idx_map(); cols.iter() @@ -1570,7 +1578,10 @@ impl DataFrame { .collect::>>()? } else { cols.iter() - .map(|c| self.column(c).map(|s| s.to_physical_repr().into_owned())) + .map(|c| { + self.column(c.as_str()) + .map(|s| s.to_physical_repr().into_owned()) + }) .collect::>>()? }; @@ -1578,7 +1589,7 @@ impl DataFrame { } /// A non generic implementation to reduce compiler bloat. - fn select_series_impl(&self, cols: &[SmartString]) -> PolarsResult> { + fn select_series_impl(&self, cols: &[PlSmallStr]) -> PolarsResult> { let selected = if cols.len() > 1 && self.columns.len() > 10 { // we hash, because there are user that having millions of columns. // # https://github.com/pola-rs/polars/issues/1023 @@ -1594,7 +1605,7 @@ impl DataFrame { .collect::>>()? } else { cols.iter() - .map(|c| self.column(c).cloned()) + .map(|c| self.column(c.as_str()).cloned()) .collect::>>()? }; @@ -1639,7 +1650,7 @@ impl DataFrame { /// ``` /// # use polars_core::prelude::*; /// fn example(df: &DataFrame) -> PolarsResult { - /// let idx = IdxCa::new("idx", &[0, 1, 9]); + /// let idx = IdxCa::new("idx".into(), [0, 1, 9]); /// df.take(&idx) /// } /// ``` @@ -1691,19 +1702,20 @@ impl DataFrame { /// fn example(df: &mut DataFrame) -> PolarsResult<&mut DataFrame> { /// let original_name = "foo"; /// let new_name = "bar"; - /// df.rename(original_name, new_name) + /// df.rename(original_name, new_name.into()) /// } /// ``` - pub fn rename(&mut self, column: &str, name: &str) -> PolarsResult<&mut Self> { + pub fn rename(&mut self, column: &str, name: PlSmallStr) -> PolarsResult<&mut Self> { + if column == name.as_str() { + return Ok(self); + } + polars_ensure!( + self.columns.iter().all(|c| c.name() != &name), + Duplicate: "column rename attempted with already existing name \"{name}\"" + ); self.select_mut(column) .ok_or_else(|| polars_err!(col_not_found = column)) .map(|s| s.rename(name))?; - let unique_names: AHashSet<&str, ahash::RandomState> = - AHashSet::from_iter(self.columns.iter().map(|s| s.name())); - polars_ensure!( - unique_names.len() == self.width(), - Duplicate: "duplicate column names found" - ); Ok(self) } @@ -1712,7 +1724,7 @@ impl DataFrame { /// See [`DataFrame::sort`] for more instruction. pub fn sort_in_place( &mut self, - by: impl IntoVec, + by: impl IntoVec, sort_options: SortMultipleOptions, ) -> PolarsResult<&mut Self> { let by_column = self.select_series(by)?; @@ -1784,7 +1796,7 @@ impl DataFrame { // fast path for a frame with a single series // no need to compute the sort indices and then take by these indices // simply sort and return as frame - if df.width() == 1 && df.check_name_to_idx(s.name()).is_ok() { + if df.width() == 1 && df.check_name_to_idx(s.name().as_str()).is_ok() { let mut out = s.sort_with(options)?; if let Some((offset, len)) = slice { out = out.slice(offset, len); @@ -1849,7 +1861,7 @@ impl DataFrame { /// # use polars_core::prelude::*; /// fn sort_by_multiple_columns_with_specific_order(df: &DataFrame) -> PolarsResult { /// df.sort( - /// &["sepal_width", "sepal_length"], + /// ["sepal_width", "sepal_length"], /// SortMultipleOptions::new() /// .with_order_descending_multi([false, true]) /// ) @@ -1860,7 +1872,7 @@ impl DataFrame { /// Also see [`DataFrame::sort_in_place`]. 
pub fn sort( &self, - by: impl IntoVec, + by: impl IntoVec, sort_options: SortMultipleOptions, ) -> PolarsResult { let mut df = self.clone(); @@ -1874,9 +1886,9 @@ impl DataFrame { /// /// ```rust /// # use polars_core::prelude::*; - /// let mut df: DataFrame = df!("Country" => &["United States", "China"], - /// "Area (km²)" => &[9_833_520, 9_596_961])?; - /// let s: Series = Series::new("Country", &["USA", "PRC"]); + /// let mut df: DataFrame = df!("Country" => ["United States", "China"], + /// "Area (km²)" => [9_833_520, 9_596_961])?; + /// let s: Series = Series::new("Country".into(), ["USA", "PRC"]); /// /// assert!(df.replace("Nation", s.clone()).is_err()); /// assert!(df.replace("Country", s).is_ok()); @@ -1891,7 +1903,7 @@ impl DataFrame { /// of the `Series` passed to this method. pub fn replace_or_add( &mut self, - column: &str, + column: PlSmallStr, new_col: S, ) -> PolarsResult<&mut Self> { let mut new_col = new_col.into_series(); @@ -1905,8 +1917,8 @@ impl DataFrame { /// /// ```ignored /// # use polars_core::prelude::*; - /// let s0 = Series::new("foo", &["ham", "spam", "egg"]); - /// let s1 = Series::new("ascii", &[70, 79, 79]); + /// let s0 = Series::new("foo".into(), ["ham", "spam", "egg"]); + /// let s1 = Series::new("ascii".into(), [70, 79, 79]); /// let mut df = DataFrame::new(vec![s0, s1])?; /// /// // Add 32 to get lowercase ascii values @@ -1942,8 +1954,8 @@ impl DataFrame { /// /// ```rust /// # use polars_core::prelude::*; - /// let s0 = Series::new("foo", &["ham", "spam", "egg"]); - /// let s1 = Series::new("names", &["Jean", "Claude", "van"]); + /// let s0 = Series::new("foo".into(), ["ham", "spam", "egg"]); + /// let s1 = Series::new("names".into(), ["Jean", "Claude", "van"]); /// let mut df = DataFrame::new(vec![s0, s1])?; /// /// fn str_to_len(str_val: &Series) -> Series { @@ -1992,8 +2004,8 @@ impl DataFrame { /// /// ```rust /// # use polars_core::prelude::*; - /// let s0 = Series::new("foo", &["ham", "spam", "egg"]); - /// let s1 = Series::new("ascii", &[70, 79, 79]); + /// let s0 = Series::new("foo".into(), ["ham", "spam", "egg"]); + /// let s1 = Series::new("ascii".into(), [70, 79, 79]); /// let mut df = DataFrame::new(vec![s0, s1])?; /// /// // Add 32 to get lowercase ascii values @@ -2028,7 +2040,7 @@ impl DataFrame { idx, width ) })?; - let name = col.name().to_string(); + let name = col.name().clone(); let new_col = f(col).into_series(); match new_col.len() { 1 => { @@ -2048,7 +2060,7 @@ impl DataFrame { // make sure the name remains the same after applying the closure unsafe { let col = self.columns.get_unchecked_mut(idx); - col.rename(&name); + col.rename(name); } Ok(self) } @@ -2062,8 +2074,8 @@ impl DataFrame { /// /// ```rust /// # use polars_core::prelude::*; - /// let s0 = Series::new("foo", &["ham", "spam", "egg", "bacon", "quack"]); - /// let s1 = Series::new("values", &[1, 2, 3, 4, 5]); + /// let s0 = Series::new("foo".into(), ["ham", "spam", "egg", "bacon", "quack"]); + /// let s1 = Series::new("values".into(), [1, 2, 3, 4, 5]); /// let mut df = DataFrame::new(vec![s0, s1])?; /// /// let idx = vec![0, 1, 4]; @@ -2105,14 +2117,14 @@ impl DataFrame { idx, width ) })?; - let name = col.name().to_string(); + let name = col.name().clone(); let _ = mem::replace(col, f(col).map(|s| s.into_series())?); // make sure the name remains the same after applying the closure unsafe { let col = self.columns.get_unchecked_mut(idx); - col.rename(&name); + col.rename(name); } Ok(self) } @@ -2126,8 +2138,8 @@ impl DataFrame { /// /// ```rust /// # use 
polars_core::prelude::*; - /// let s0 = Series::new("foo", &["ham", "spam", "egg", "bacon", "quack"]); - /// let s1 = Series::new("values", &[1, 2, 3, 4, 5]); + /// let s0 = Series::new("foo".into(), ["ham", "spam", "egg", "bacon", "quack"]); + /// let s1 = Series::new("values".into(), [1, 2, 3, 4, 5]); /// let mut df = DataFrame::new(vec![s0, s1])?; /// /// // create a mask @@ -2174,8 +2186,8 @@ impl DataFrame { /// /// ```rust /// # use polars_core::prelude::*; - /// let df: DataFrame = df!("Fruit" => &["Apple", "Grape", "Grape", "Fig", "Fig"], - /// "Color" => &["Green", "Red", "White", "White", "Red"])?; + /// let df: DataFrame = df!("Fruit" => ["Apple", "Grape", "Grape", "Fig", "Fig"], + /// "Color" => ["Green", "Red", "White", "White", "Red"])?; /// let sl: DataFrame = df.slice(2, 3); /// /// assert_eq!(sl.shape(), (3, 2)); @@ -2255,10 +2267,10 @@ impl DataFrame { /// ```rust /// # use polars_core::prelude::*; /// let countries: DataFrame = - /// df!("Rank by GDP (2021)" => &[1, 2, 3, 4, 5], - /// "Continent" => &["North America", "Asia", "Asia", "Europe", "Europe"], - /// "Country" => &["United States", "China", "Japan", "Germany", "United Kingdom"], - /// "Capital" => &["Washington", "Beijing", "Tokyo", "Berlin", "London"])?; + /// df!("Rank by GDP (2021)" => [1, 2, 3, 4, 5], + /// "Continent" => ["North America", "Asia", "Asia", "Europe", "Europe"], + /// "Country" => ["United States", "China", "Japan", "Germany", "United Kingdom"], + /// "Capital" => ["Washington", "Beijing", "Tokyo", "Berlin", "London"])?; /// assert_eq!(countries.shape(), (5, 4)); /// /// println!("{}", countries.head(Some(3))); @@ -2298,9 +2310,9 @@ impl DataFrame { /// ```rust /// # use polars_core::prelude::*; /// let countries: DataFrame = - /// df!("Rank (2021)" => &[105, 106, 107, 108, 109], - /// "Apple Price (€/kg)" => &[0.75, 0.70, 0.70, 0.65, 0.52], - /// "Country" => &["Kosovo", "Moldova", "North Macedonia", "Syria", "Turkey"])?; + /// df!("Rank (2021)" => [105, 106, 107, 108, 109], + /// "Apple Price (€/kg)" => [0.75, 0.70, 0.70, 0.65, 0.52], + /// "Country" => ["Kosovo", "Moldova", "North Macedonia", "Syria", "Turkey"])?; /// assert_eq!(countries.shape(), (5, 3)); /// /// println!("{}", countries.tail(Some(2))); @@ -2391,7 +2403,6 @@ impl DataFrame { #[must_use] pub fn shift(&self, periods: i64) -> Self { let col = self._apply_columns_par(&|s| s.shift(periods)); - unsafe { DataFrame::new_no_checks(col) } } @@ -2656,32 +2667,39 @@ impl DataFrame { keep: UniqueKeepStrategy, slice: Option<(i64, usize)>, ) -> PolarsResult { - self.unique_impl(true, subset, keep, slice) + self.unique_impl( + true, + subset.map(|v| v.iter().map(|x| PlSmallStr::from_str(x.as_str())).collect()), + keep, + slice, + ) } /// Unstable distinct. See [`DataFrame::unique_stable`]. 
#[cfg(feature = "algorithm_group_by")] - pub fn unique( + pub fn unique( &self, subset: Option<&[String]>, keep: UniqueKeepStrategy, slice: Option<(i64, usize)>, ) -> PolarsResult { - self.unique_impl(false, subset, keep, slice) + self.unique_impl( + false, + subset.map(|v| v.iter().map(|x| PlSmallStr::from_str(x.as_str())).collect()), + keep, + slice, + ) } #[cfg(feature = "algorithm_group_by")] pub fn unique_impl( &self, maintain_order: bool, - subset: Option<&[String]>, + subset: Option>, keep: UniqueKeepStrategy, slice: Option<(i64, usize)>, ) -> PolarsResult { - let names = match &subset { - Some(s) => s.iter().map(|s| &**s).collect(), - None => self.get_column_names(), - }; + let names = subset.unwrap_or_else(|| self.get_column_names_owned()); let mut df = self.clone(); // take on multiple chunks is terrible df.as_single_chunk_par(); @@ -2749,8 +2767,8 @@ impl DataFrame { /// /// ```no_run /// # use polars_core::prelude::*; - /// let df: DataFrame = df!("Company" => &["Apple", "Microsoft"], - /// "ISIN" => &["US0378331005", "US5949181045"])?; + /// let df: DataFrame = df!("Company" => ["Apple", "Microsoft"], + /// "ISIN" => ["US0378331005", "US5949181045"])?; /// let ca: ChunkedArray = df.is_unique()?; /// /// assert!(ca.all()); @@ -2758,7 +2776,7 @@ impl DataFrame { /// ``` #[cfg(feature = "algorithm_group_by")] pub fn is_unique(&self) -> PolarsResult { - let gb = self.group_by(self.get_column_names())?; + let gb = self.group_by(self.get_column_names_owned())?; let groups = gb.take_groups(); Ok(is_unique_helper( groups, @@ -2774,8 +2792,8 @@ impl DataFrame { /// /// ```no_run /// # use polars_core::prelude::*; - /// let df: DataFrame = df!("Company" => &["Alphabet", "Alphabet"], - /// "ISIN" => &["US02079K3059", "US02079K1079"])?; + /// let df: DataFrame = df!("Company" => ["Alphabet", "Alphabet"], + /// "ISIN" => ["US02079K3059", "US02079K1079"])?; /// let ca: ChunkedArray = df.is_duplicated()?; /// /// assert!(!ca.all()); @@ -2783,7 +2801,7 @@ impl DataFrame { /// ``` #[cfg(feature = "algorithm_group_by")] pub fn is_duplicated(&self) -> PolarsResult { - let gb = self.group_by(self.get_column_names())?; + let gb = self.group_by(self.get_column_names_owned())?; let groups = gb.take_groups(); Ok(is_unique_helper( groups, @@ -2799,7 +2817,7 @@ impl DataFrame { let cols = self .columns .iter() - .map(|s| Series::new(s.name(), &[s.null_count() as IdxSize])) + .map(|s| Series::new(s.name().clone(), [s.null_count() as IdxSize])) .collect(); unsafe { Self::new_no_checks(cols) } } @@ -2808,7 +2826,7 @@ impl DataFrame { #[cfg(feature = "row_hash")] pub fn hash_rows( &mut self, - hasher_builder: Option, + hasher_builder: Option, ) -> PolarsResult { let dfs = split_df(self, POOL.current_num_threads(), false); let (cas, _) = _df_rows_to_hashes_threaded_vertical(&dfs, hasher_builder)?; @@ -2816,7 +2834,7 @@ impl DataFrame { let mut iter = cas.into_iter(); let mut acc_ca = iter.next().unwrap(); for ca in iter { - acc_ca.append(&ca); + acc_ca.append(&ca)?; } Ok(acc_ca.rechunk()) } @@ -2865,7 +2883,7 @@ impl DataFrame { } } } - let mut ca = IdxCa::mmap_slice("", idx); + let mut ca = IdxCa::mmap_slice(PlSmallStr::EMPTY, idx); ca.set_sorted_flag(sorted); self.take_unchecked_impl(&ca, allow_threads) } @@ -2874,21 +2892,21 @@ impl DataFrame { #[doc(hidden)] pub fn _partition_by_impl( &self, - cols: &[String], + cols: &[PlSmallStr], stable: bool, include_key: bool, ) -> PolarsResult> { let groups = if stable { - self.group_by_stable(cols)?.take_groups() + 
self.group_by_stable(cols.iter().cloned())?.take_groups() } else { - self.group_by(cols)?.take_groups() + self.group_by(cols.iter().cloned())?.take_groups() }; // drop key columns prior to calculation if requested let df = if include_key { self.clone() } else { - self.drop_many(cols) + self.drop_many(cols.iter().cloned()) }; // don't parallelize this @@ -2919,37 +2937,47 @@ impl DataFrame { /// Split into multiple DataFrames partitioned by groups #[cfg(feature = "partition_by")] - pub fn partition_by( - &self, - cols: impl IntoVec, - include_key: bool, - ) -> PolarsResult> { - let cols = cols.into_vec(); - self._partition_by_impl(&cols, false, include_key) + pub fn partition_by(&self, cols: I, include_key: bool) -> PolarsResult> + where + I: IntoIterator, + S: Into, + { + let cols = cols + .into_iter() + .map(Into::into) + .collect::>(); + self._partition_by_impl(cols.as_slice(), false, include_key) } /// Split into multiple DataFrames partitioned by groups /// Order of the groups are maintained. #[cfg(feature = "partition_by")] - pub fn partition_by_stable( + pub fn partition_by_stable( &self, - cols: impl IntoVec, + cols: I, include_key: bool, - ) -> PolarsResult> { - let cols = cols.into_vec(); - self._partition_by_impl(&cols, true, include_key) + ) -> PolarsResult> + where + I: IntoIterator, + S: Into, + { + let cols = cols + .into_iter() + .map(Into::into) + .collect::>(); + self._partition_by_impl(cols.as_slice(), true, include_key) } /// Unnest the given `Struct` columns. This means that the fields of the `Struct` type will be /// inserted as columns. #[cfg(feature = "dtype-struct")] - pub fn unnest>(&self, cols: I) -> PolarsResult { + pub fn unnest>(&self, cols: I) -> PolarsResult { let cols = cols.into_vec(); self.unnest_impl(cols.into_iter().collect()) } #[cfg(feature = "dtype-struct")] - fn unnest_impl(&self, cols: PlHashSet) -> PolarsResult { + fn unnest_impl(&self, cols: PlHashSet) -> PolarsResult { let mut new_cols = Vec::with_capacity(std::cmp::min(self.width() * 2, self.width() + 128)); let mut count = 0; for s in &self.columns { @@ -2967,7 +2995,7 @@ impl DataFrame { let schema = self.schema(); for col in cols { let _ = schema - .get(&col) + .get(col.as_str()) .ok_or_else(|| polars_err!(col_not_found = col))?; } } @@ -3066,8 +3094,8 @@ mod test { use super::*; fn create_frame() -> DataFrame { - let s0 = Series::new("days", [0, 1, 2].as_ref()); - let s1 = Series::new("temp", [22.1, 19.9, 7.].as_ref()); + let s0 = Series::new("days".into(), [0, 1, 2].as_ref()); + let s1 = Series::new("temp".into(), [22.1, 19.9, 7.].as_ref()); DataFrame::new(vec![s0, s1]).unwrap() } @@ -3075,7 +3103,7 @@ mod test { #[cfg_attr(miri, ignore)] fn test_recordbatch_iterator() { let df = df!( - "foo" => &[1, 2, 3, 4, 5] + "foo" => [1, 2, 3, 4, 5] ) .unwrap(); let mut iter = df.iter_chunks(CompatLevel::newest(), false); @@ -3095,7 +3123,7 @@ mod test { fn test_filter_broadcast_on_string_col() { let col_name = "some_col"; let v = vec!["test".to_string()]; - let s0 = Series::new(col_name, v); + let s0 = Series::new(PlSmallStr::from_str(col_name), v); let mut df = DataFrame::new(vec![s0]).unwrap(); df = df @@ -3107,10 +3135,10 @@ mod test { #[test] #[cfg_attr(miri, ignore)] fn test_filter_broadcast_on_list_col() { - let s1 = Series::new("", &[true, false, true]); + let s1 = Series::new(PlSmallStr::EMPTY, [true, false, true]); let ll: ListChunked = [&s1].iter().copied().collect(); - let mask = BooleanChunked::from_slice("", &[false]); + let mask = BooleanChunked::from_slice(PlSmallStr::EMPTY, &[false]); 
let new = ll.filter(&mask).unwrap(); assert_eq!(new.chunks.len(), 1); @@ -3138,8 +3166,8 @@ mod test { )?; // Create a series with multiple chunks - let mut s = Series::new("foo", 0..2); - let s2 = Series::new("bar", 0..1); + let mut s = Series::new("foo".into(), 0..2); + let s2 = Series::new("bar".into(), 0..1); s.append(&s2)?; // Append series to frame @@ -3153,12 +3181,16 @@ mod test { #[test] fn test_duplicate_column() { let mut df = df! { - "foo" => &[1, 2, 3] + "foo" => [1, 2, 3] } .unwrap(); // check if column is replaced - assert!(df.with_column(Series::new("foo", &[1, 2, 3])).is_ok()); - assert!(df.with_column(Series::new("bar", &[1, 2, 3])).is_ok()); + assert!(df + .with_column(Series::new("foo".into(), &[1, 2, 3])) + .is_ok()); + assert!(df + .with_column(Series::new("bar".into(), &[1, 2, 3])) + .is_ok()); assert!(df.column("bar").is_ok()) } @@ -3203,9 +3235,9 @@ mod test { #[cfg(feature = "zip_with")] #[cfg_attr(miri, ignore)] fn test_horizontal_agg() { - let a = Series::new("a", &[1, 2, 6]); - let b = Series::new("b", &[Some(1), None, None]); - let c = Series::new("c", &[Some(4), None, Some(3)]); + let a = Series::new("a".into(), [1, 2, 6]); + let b = Series::new("b".into(), [Some(1), None, None]); + let c = Series::new("c".into(), [Some(4), None, Some(3)]); let df = DataFrame::new(vec![a, b, c]).unwrap(); assert_eq!( @@ -3246,7 +3278,7 @@ mod test { )?; // check that the new column is "c" and not "bar". - df.replace_or_add("c", Series::new("bar", [1, 2, 3]))?; + df.replace_or_add("c".into(), Series::new("bar".into(), [1, 2, 3]))?; assert_eq!(df.get_column_names(), &["a", "b", "c"]); Ok(()) @@ -3261,13 +3293,13 @@ mod test { // has got columns, but no rows let mut df = base.clear(); - let out = df.with_column(Series::new("c", [1]))?; + let out = df.with_column(Series::new("c".into(), [1]))?; assert_eq!(out.shape(), (0, 3)); assert!(out.iter().all(|s| s.len() == 0)); // no columns base.columns = vec![]; - let out = base.with_column(Series::new("c", [1]))?; + let out = base.with_column(Series::new("c".into(), [1]))?; assert_eq!(out.shape(), (1, 1)); Ok(()) diff --git a/crates/polars-core/src/frame/row/av_buffer.rs b/crates/polars-core/src/frame/row/av_buffer.rs index cedf4da1799f..608d6ec820af 100644 --- a/crates/polars-core/src/frame/row/av_buffer.rs +++ b/crates/polars-core/src/frame/row/av_buffer.rs @@ -1,8 +1,8 @@ #[cfg(feature = "dtype-struct")] +use polars_utils::pl_str::PlSmallStr; +#[cfg(feature = "dtype-struct")] use polars_utils::slice::GetSaferUnchecked; use polars_utils::unreachable_unchecked_release; -#[cfg(feature = "dtype-struct")] -use smartstring::alias::String as SmartString; use super::*; use crate::chunked_array::builder::NullChunkedBuilder; @@ -67,7 +67,7 @@ impl<'a> AnyValueBuffer<'a> { (Float32(builder), val) => builder.append_value(val.extract()?), (Float64(builder), val) => builder.append_value(val.extract()?), (String(builder), AnyValue::String(v)) => builder.append_value(v), - (String(builder), AnyValue::StringOwned(v)) => builder.append_value(v), + (String(builder), AnyValue::StringOwned(v)) => builder.append_value(v.as_str()), (String(builder), AnyValue::Null) => builder.append_null(), #[cfg(feature = "dtype-i8")] (Int8(builder), AnyValue::Null) => builder.append_null(), @@ -151,39 +151,39 @@ impl<'a> AnyValueBuffer<'a> { use AnyValueBuffer::*; match self { Boolean(b) => { - let mut new = BooleanChunkedBuilder::new(b.field.name(), capacity); + let mut new = BooleanChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); 
new.finish().into_series() }, Int32(b) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, Int64(b) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, UInt32(b) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, UInt64(b) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, #[cfg(feature = "dtype-date")] Date(b) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_date().into_series() }, #[cfg(feature = "dtype-datetime")] Datetime(b, tu, tz) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); let tz = if capacity > 0 { tz.clone() @@ -194,62 +194,63 @@ impl<'a> AnyValueBuffer<'a> { }, #[cfg(feature = "dtype-duration")] Duration(b, tu) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_duration(*tu).into_series() }, #[cfg(feature = "dtype-time")] Time(b) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_time().into_series() }, Float32(b) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, Float64(b) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, String(b) => { - let mut new = StringChunkedBuilder::new(b.field.name(), capacity); + let mut new = StringChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, #[cfg(feature = "dtype-i8")] Int8(b) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, #[cfg(feature = "dtype-i16")] Int16(b) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, #[cfg(feature = "dtype-u8")] UInt8(b) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, #[cfg(feature = "dtype-u16")] 
UInt16(b) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, Null(b) => { - let mut new = NullChunkedBuilder::new(b.field.name(), 0); + let mut new = NullChunkedBuilder::new(b.field.name().clone(), 0); std::mem::swap(&mut new, b); new.finish().into_series() }, All(dtype, vals) => { - let out = Series::from_any_values_and_dtype("", vals, dtype, false).unwrap(); + let out = Series::from_any_values_and_dtype(PlSmallStr::EMPTY, vals, dtype, false) + .unwrap(); let mut new = Vec::with_capacity(capacity); std::mem::swap(&mut new, vals); out @@ -272,33 +273,41 @@ impl From<(&DataType, usize)> for AnyValueBuffer<'_> { let (dt, len) = a; use DataType::*; match dt { - Boolean => AnyValueBuffer::Boolean(BooleanChunkedBuilder::new("", len)), - Int32 => AnyValueBuffer::Int32(PrimitiveChunkedBuilder::new("", len)), - Int64 => AnyValueBuffer::Int64(PrimitiveChunkedBuilder::new("", len)), - UInt32 => AnyValueBuffer::UInt32(PrimitiveChunkedBuilder::new("", len)), - UInt64 => AnyValueBuffer::UInt64(PrimitiveChunkedBuilder::new("", len)), + Boolean => AnyValueBuffer::Boolean(BooleanChunkedBuilder::new(PlSmallStr::EMPTY, len)), + Int32 => AnyValueBuffer::Int32(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)), + Int64 => AnyValueBuffer::Int64(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)), + UInt32 => AnyValueBuffer::UInt32(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)), + UInt64 => AnyValueBuffer::UInt64(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)), #[cfg(feature = "dtype-i8")] - Int8 => AnyValueBuffer::Int8(PrimitiveChunkedBuilder::new("", len)), + Int8 => AnyValueBuffer::Int8(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)), #[cfg(feature = "dtype-i16")] - Int16 => AnyValueBuffer::Int16(PrimitiveChunkedBuilder::new("", len)), + Int16 => AnyValueBuffer::Int16(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)), #[cfg(feature = "dtype-u8")] - UInt8 => AnyValueBuffer::UInt8(PrimitiveChunkedBuilder::new("", len)), + UInt8 => AnyValueBuffer::UInt8(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)), #[cfg(feature = "dtype-u16")] - UInt16 => AnyValueBuffer::UInt16(PrimitiveChunkedBuilder::new("", len)), + UInt16 => AnyValueBuffer::UInt16(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)), #[cfg(feature = "dtype-date")] - Date => AnyValueBuffer::Date(PrimitiveChunkedBuilder::new("", len)), + Date => AnyValueBuffer::Date(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)), #[cfg(feature = "dtype-datetime")] - Datetime(tu, tz) => { - AnyValueBuffer::Datetime(PrimitiveChunkedBuilder::new("", len), *tu, tz.clone()) - }, + Datetime(tu, tz) => AnyValueBuffer::Datetime( + PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len), + *tu, + tz.clone(), + ), #[cfg(feature = "dtype-duration")] - Duration(tu) => AnyValueBuffer::Duration(PrimitiveChunkedBuilder::new("", len), *tu), + Duration(tu) => { + AnyValueBuffer::Duration(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len), *tu) + }, #[cfg(feature = "dtype-time")] - Time => AnyValueBuffer::Time(PrimitiveChunkedBuilder::new("", len)), - Float32 => AnyValueBuffer::Float32(PrimitiveChunkedBuilder::new("", len)), - Float64 => AnyValueBuffer::Float64(PrimitiveChunkedBuilder::new("", len)), - String => AnyValueBuffer::String(StringChunkedBuilder::new("", len)), - Null => AnyValueBuffer::Null(NullChunkedBuilder::new("", 0)), + Time => 
AnyValueBuffer::Time(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)), + Float32 => { + AnyValueBuffer::Float32(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + Float64 => { + AnyValueBuffer::Float64(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + String => AnyValueBuffer::String(StringChunkedBuilder::new(PlSmallStr::EMPTY, len)), + Null => AnyValueBuffer::Null(NullChunkedBuilder::new(PlSmallStr::EMPTY, 0)), // Struct and List can be recursive so use AnyValues for that dt => AnyValueBuffer::All(dt.clone(), Vec::with_capacity(len)), } @@ -326,7 +335,7 @@ pub enum AnyValueBufferTrusted<'a> { String(StringChunkedBuilder), #[cfg(feature = "dtype-struct")] // not the trusted variant! - Struct(Vec<(AnyValueBuffer<'a>, SmartString)>), + Struct(Vec<(AnyValueBuffer<'a>, PlSmallStr)>), Null(NullChunkedBuilder), All(DataType, Vec>), } @@ -471,7 +480,7 @@ impl<'a> AnyValueBufferTrusted<'a> { let AnyValue::StringOwned(v) = val else { unreachable_unchecked_release!() }; - builder.append_value(v) + builder.append_value(v.as_str()) }, #[cfg(feature = "dtype-struct")] Struct(builders) => { @@ -542,66 +551,66 @@ impl<'a> AnyValueBufferTrusted<'a> { use AnyValueBufferTrusted::*; match self { Boolean(b) => { - let mut new = BooleanChunkedBuilder::new(b.field.name(), capacity); + let mut new = BooleanChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, Int32(b) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, Int64(b) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, UInt32(b) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, UInt64(b) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, Float32(b) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, Float64(b) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, String(b) => { - let mut new = StringChunkedBuilder::new(b.field.name(), capacity); + let mut new = StringChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, #[cfg(feature = "dtype-i8")] Int8(b) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, #[cfg(feature = "dtype-i16")] Int16(b) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, #[cfg(feature = 
"dtype-u8")] UInt8(b) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, #[cfg(feature = "dtype-u16")] UInt16(b) => { - let mut new = PrimitiveChunkedBuilder::new(b.field.name(), capacity); + let mut new = PrimitiveChunkedBuilder::new(b.field.name().clone(), capacity); std::mem::swap(&mut new, b); new.finish().into_series() }, @@ -611,21 +620,24 @@ impl<'a> AnyValueBufferTrusted<'a> { .iter_mut() .map(|(b, name)| { let mut s = b.reset(capacity); - s.rename(name.as_str()); + s.rename(name.clone()); s }) .collect::>(); - StructChunked::from_series("", &v).unwrap().into_series() + StructChunked::from_series(PlSmallStr::EMPTY, &v) + .unwrap() + .into_series() }, Null(b) => { - let mut new = NullChunkedBuilder::new(b.field.name(), 0); + let mut new = NullChunkedBuilder::new(b.field.name().clone(), 0); std::mem::swap(&mut new, b); new.finish().into_series() }, All(dtype, vals) => { let mut swap_vals = Vec::with_capacity(capacity); std::mem::swap(vals, &mut swap_vals); - Series::from_any_values_and_dtype("", &swap_vals, dtype, false).unwrap() + Series::from_any_values_and_dtype(PlSmallStr::EMPTY, &swap_vals, dtype, false) + .unwrap() }, } } @@ -640,28 +652,52 @@ impl From<(&DataType, usize)> for AnyValueBufferTrusted<'_> { let (dt, len) = a; use DataType::*; match dt { - Boolean => AnyValueBufferTrusted::Boolean(BooleanChunkedBuilder::new("", len)), - Int32 => AnyValueBufferTrusted::Int32(PrimitiveChunkedBuilder::new("", len)), - Int64 => AnyValueBufferTrusted::Int64(PrimitiveChunkedBuilder::new("", len)), - UInt32 => AnyValueBufferTrusted::UInt32(PrimitiveChunkedBuilder::new("", len)), - UInt64 => AnyValueBufferTrusted::UInt64(PrimitiveChunkedBuilder::new("", len)), + Boolean => { + AnyValueBufferTrusted::Boolean(BooleanChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + Int32 => { + AnyValueBufferTrusted::Int32(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + Int64 => { + AnyValueBufferTrusted::Int64(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + UInt32 => { + AnyValueBufferTrusted::UInt32(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + UInt64 => { + AnyValueBufferTrusted::UInt64(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, #[cfg(feature = "dtype-i8")] - Int8 => AnyValueBufferTrusted::Int8(PrimitiveChunkedBuilder::new("", len)), + Int8 => { + AnyValueBufferTrusted::Int8(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, #[cfg(feature = "dtype-i16")] - Int16 => AnyValueBufferTrusted::Int16(PrimitiveChunkedBuilder::new("", len)), + Int16 => { + AnyValueBufferTrusted::Int16(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, #[cfg(feature = "dtype-u8")] - UInt8 => AnyValueBufferTrusted::UInt8(PrimitiveChunkedBuilder::new("", len)), + UInt8 => { + AnyValueBufferTrusted::UInt8(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, #[cfg(feature = "dtype-u16")] - UInt16 => AnyValueBufferTrusted::UInt16(PrimitiveChunkedBuilder::new("", len)), - Float32 => AnyValueBufferTrusted::Float32(PrimitiveChunkedBuilder::new("", len)), - Float64 => AnyValueBufferTrusted::Float64(PrimitiveChunkedBuilder::new("", len)), - String => AnyValueBufferTrusted::String(StringChunkedBuilder::new("", len)), + UInt16 => { + AnyValueBufferTrusted::UInt16(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + Float32 => { + 
AnyValueBufferTrusted::Float32(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + Float64 => { + AnyValueBufferTrusted::Float64(PrimitiveChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, + String => { + AnyValueBufferTrusted::String(StringChunkedBuilder::new(PlSmallStr::EMPTY, len)) + }, #[cfg(feature = "dtype-struct")] Struct(fields) => { let buffers = fields .iter() .map(|field| { - let dtype = field.data_type().to_physical(); + let dtype = field.dtype().to_physical(); let buffer: AnyValueBuffer = (&dtype, len).into(); (buffer, field.name.clone()) }) diff --git a/crates/polars-core/src/frame/row/dataframe.rs b/crates/polars-core/src/frame/row/dataframe.rs index f9e60cebcd0e..4a40a9ed6d6f 100644 --- a/crates/polars-core/src/frame/row/dataframe.rs +++ b/crates/polars-core/src/frame/row/dataframe.rs @@ -56,7 +56,7 @@ impl DataFrame { let capacity = rows.size_hint().0; let mut buffers: Vec<_> = schema - .iter_dtypes() + .iter_values() .map(|dtype| { let buf: AnyValueBuffer = (dtype, capacity).into(); buf @@ -79,9 +79,9 @@ impl DataFrame { // if the schema adds a column not in the rows, we // fill it with nulls if s.is_empty() { - Series::full_null(name, expected_len, s.dtype()) + Series::full_null(name.clone(), expected_len, s.dtype()) } else { - s.rename(name); + s.rename(name.clone()); s } }) @@ -98,7 +98,7 @@ impl DataFrame { let capacity = rows.size_hint().0; let mut buffers: Vec<_> = schema - .iter_dtypes() + .iter_values() .map(|dtype| { let buf: AnyValueBuffer = (dtype, capacity).into(); buf @@ -121,9 +121,9 @@ impl DataFrame { // if the schema adds a column not in the rows, we // fill it with nulls if s.is_empty() { - Series::full_null(name, expected_len, s.dtype()) + Series::full_null(name.clone(), expected_len, s.dtype()) } else { - s.rename(name); + s.rename(name.clone()); s } }) @@ -136,7 +136,7 @@ impl DataFrame { pub fn from_rows(rows: &[Row]) -> PolarsResult { let schema = rows_to_schema_first_non_null(rows, Some(50))?; let has_nulls = schema - .iter_dtypes() + .iter_values() .any(|dtype| matches!(dtype, DataType::Null)); polars_ensure!( !has_nulls, ComputeError: "unable to infer row types because of null values" diff --git a/crates/polars-core/src/frame/row/mod.rs b/crates/polars-core/src/frame/row/mod.rs index e9cf92ffad13..44e445b0874e 100644 --- a/crates/polars-core/src/frame/row/mod.rs +++ b/crates/polars-core/src/frame/row/mod.rs @@ -10,6 +10,7 @@ use std::hint::unreachable_unchecked; use arrow::bitmap::Bitmap; pub use av_buffer::*; +use polars_utils::format_pl_smallstr; #[cfg(feature = "object")] use polars_utils::total_ord::TotalHash; use rayon::prelude::*; @@ -96,10 +97,10 @@ impl<'a> Row<'a> { } } -type Tracker = PlIndexMap>; +type Tracker = PlIndexMap>; pub fn infer_schema( - iter: impl Iterator)>>, + iter: impl Iterator, impl Into)>>, infer_schema_length: usize, ) -> Schema { let mut values: Tracker = Tracker::default(); @@ -108,25 +109,25 @@ pub fn infer_schema( let max_infer = std::cmp::min(len, infer_schema_length); for inner in iter.take(max_infer) { for (key, value) in inner { - add_or_insert(&mut values, &key, value.into()); + add_or_insert(&mut values, key.into(), value.into()); } } Schema::from_iter(resolve_fields(values)) } -fn add_or_insert(values: &mut Tracker, key: &str, data_type: DataType) { - if data_type == DataType::Null { +fn add_or_insert(values: &mut Tracker, key: PlSmallStr, dtype: DataType) { + if dtype == DataType::Null { return; } - if values.contains_key(key) { - let x = values.get_mut(key).unwrap(); - x.insert(data_type); + if 
values.contains_key(&key) { + let x = values.get_mut(&key).unwrap(); + x.insert(dtype); } else { // create hashset and add value type let mut hs = PlHashSet::new(); - hs.insert(data_type); - values.insert(key.to_string(), hs); + hs.insert(dtype); + values.insert(key, hs); } } @@ -134,13 +135,13 @@ fn resolve_fields(spec: Tracker) -> Vec { spec.iter() .map(|(k, hs)| { let v: Vec<&DataType> = hs.iter().collect(); - Field::new(k, coerce_data_type(&v)) + Field::new(k.clone(), coerce_dtype(&v)) }) .collect() } /// Coerces a slice of datatypes into a single supertype. -pub fn coerce_data_type>(datatypes: &[A]) -> DataType { +pub fn coerce_dtype>(datatypes: &[A]) -> DataType { use DataType::*; let are_all_equal = datatypes.windows(2).all(|w| w[0].borrow() == w[1].borrow()); @@ -206,7 +207,7 @@ pub fn rows_to_schema_first_non_null( for row in rows.iter().take(max_infer).skip(1) { // for i in 1..max_infer { let nulls: Vec<_> = schema - .iter_dtypes() + .iter_values() .enumerate() .filter_map(|(i, dtype)| { // double check struct and list types types @@ -237,7 +238,7 @@ pub fn rows_to_schema_first_non_null( impl<'a> From<&AnyValue<'a>> for Field { fn from(val: &AnyValue<'a>) -> Self { - Field::new("", val.into()) + Field::new(PlSmallStr::EMPTY, val.into()) } } @@ -248,7 +249,7 @@ impl From<&Row<'_>> for Schema { .enumerate() .map(|(i, av)| { let dtype = av.into(); - Field::new(format!("column_{i}").as_ref(), dtype) + Field::new(format_pl_smallstr!("column_{i}"), dtype) }) .collect() } diff --git a/crates/polars-core/src/frame/row/transpose.rs b/crates/polars-core/src/frame/row/transpose.rs index 7ad4bc4f1fef..1984a085116f 100644 --- a/crates/polars-core/src/frame/row/transpose.rs +++ b/crates/polars-core/src/frame/row/transpose.rs @@ -8,8 +8,8 @@ impl DataFrame { pub(crate) fn transpose_from_dtype( &self, dtype: &DataType, - keep_names_as: Option<&str>, - names_out: &[String], + keep_names_as: Option, + names_out: &[PlSmallStr], ) -> PolarsResult { let new_width = self.height(); let new_height = self.width(); @@ -18,7 +18,13 @@ impl DataFrame { None => Vec::::with_capacity(new_width), Some(name) => { let mut tmp = Vec::::with_capacity(new_width + 1); - tmp.push(StringChunked::new(name, self.get_column_names()).into()); + tmp.push( + StringChunked::from_iter_values( + name, + self.get_column_names_owned().into_iter(), + ) + .into(), + ); tmp }, }; @@ -74,7 +80,7 @@ impl DataFrame { cols_t.extend(buffers.into_iter().zip(names_out).map(|(buf, name)| { // SAFETY: we are casting back to the supertype let mut s = unsafe { buf.into_series().cast_unchecked(dtype).unwrap() }; - s.rename(name); + s.rename(name.clone()); s })); }, @@ -82,26 +88,43 @@ impl DataFrame { Ok(unsafe { DataFrame::new_no_checks(cols_t) }) } - /// Transpose a DataFrame. This is a very expensive operation. pub fn transpose( &mut self, keep_names_as: Option<&str>, new_col_names: Option>>, + ) -> PolarsResult { + let new_col_names = match new_col_names { + None => None, + Some(Either::Left(v)) => Some(Either::Left(v.into())), + Some(Either::Right(v)) => Some(Either::Right( + v.into_iter().map(Into::into).collect::>(), + )), + }; + + self.transpose_impl(keep_names_as, new_col_names) + } + /// Transpose a DataFrame. This is a very expensive operation. + pub fn transpose_impl( + &mut self, + keep_names_as: Option<&str>, + new_col_names: Option>>, ) -> PolarsResult { // We must iterate columns as [`AnyValue`], so we must be contiguous. 
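A minimal sketch, not part of the patch: after the `iter_dtypes` -> `iter_values` rename, code that scans a schema's dtypes reads as below. The helper name is invented and it assumes the `polars_core::prelude` re-exports used throughout this patch.

use polars_core::prelude::*;

// Mirrors the null-dtype check in DataFrame::from_rows above: a schema still
// containing DataType::Null means row-based type inference could not resolve a column.
fn has_untyped_columns(schema: &Schema) -> bool {
    schema
        .iter_values()
        .any(|dtype| matches!(dtype, DataType::Null))
}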
self.as_single_chunk_par(); let mut df = Cow::Borrowed(self); // Can't use self because we might drop a name column let names_out = match new_col_names { - None => (0..self.height()).map(|i| format!("column_{i}")).collect(), + None => (0..self.height()) + .map(|i| format_pl_smallstr!("column_{i}")) + .collect(), Some(cn) => match cn { Either::Left(name) => { - let new_names = self.column(&name).and_then(|x| x.str())?; + let new_names = self.column(name.as_str()).and_then(|x| x.str())?; polars_ensure!(new_names.null_count() == 0, ComputeError: "Column with new names can't have null values"); - df = Cow::Owned(self.drop(&name)?); + df = Cow::Owned(self.drop(name.as_str())?); new_names .into_no_null_iter() - .map(|s| s.to_owned()) + .map(PlSmallStr::from_str) .collect() }, Either::Right(names) => { @@ -141,7 +164,7 @@ impl DataFrame { }, _ => {}, } - df.transpose_from_dtype(&dtype, keep_names_as, &names_out) + df.transpose_from_dtype(&dtype, keep_names_as.map(PlSmallStr::from_str), &names_out) } } @@ -159,8 +182,11 @@ unsafe fn add_value( // This just fills a pre-allocated mutable series vector, which may have a name column. // Nothing is returned and the actual DataFrame is constructed above. -pub(super) fn numeric_transpose(cols: &[Series], names_out: &[String], cols_t: &mut Vec) -where +pub(super) fn numeric_transpose( + cols: &[Series], + names_out: &[PlSmallStr], + cols_t: &mut Vec, +) where T: PolarsNumericType, //S: AsRef, ChunkedArray: IntoSeries, @@ -251,7 +277,7 @@ where values.into(), validity, ); - ChunkedArray::with_chunk(name.as_str(), arr).into_series() + ChunkedArray::with_chunk(name.clone(), arr).into_series() }); POOL.install(|| cols_t.par_extend(par_iter)); } diff --git a/crates/polars-core/src/functions.rs b/crates/polars-core/src/functions.rs index 6ca5548f000f..57cbee3a01dc 100644 --- a/crates/polars-core/src/functions.rs +++ b/crates/polars-core/src/functions.rs @@ -19,8 +19,8 @@ pub fn concat_df_diagonal(dfs: &[DataFrame]) -> PolarsResult { for df in dfs { df.get_columns().iter().for_each(|s| { - let name = s.name(); - if column_names.insert(name) { + let name = s.name().clone(); + if column_names.insert(name.clone()) { schema.push((name, s.dtype())) } }); @@ -33,9 +33,9 @@ pub fn concat_df_diagonal(dfs: &[DataFrame]) -> PolarsResult { let mut columns = Vec::with_capacity(schema.len()); for (name, dtype) in &schema { - match df.column(name).ok() { + match df.column(name.as_str()).ok() { Some(s) => columns.push(s.clone()), - None => columns.push(Series::full_null(name, height, dtype)), + None => columns.push(Series::full_null(name.clone(), height, dtype)), } } unsafe { DataFrame::new_no_checks(columns) } diff --git a/crates/polars-core/src/hashing/identity.rs b/crates/polars-core/src/hashing/identity.rs index 7554395ac50c..e917291f1586 100644 --- a/crates/polars-core/src/hashing/identity.rs +++ b/crates/polars-core/src/hashing/identity.rs @@ -36,6 +36,7 @@ pub type IdBuildHasher = BuildHasherDefault; #[derive(Debug)] /// Contains an idx of a row in a DataFrame and the precomputed hash of that row. +/// /// That hash still needs to be used to create another hash to be able to resize hashmaps without /// accidental quadratic behavior. So do not use an Identity function! 
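A minimal sketch (not part of the patch) of the owned-name convention that `concat_df_diagonal` now follows: `Series::full_null` takes a `PlSmallStr`, so a borrowed name must be cloned. The helper name is invented.

use polars_core::prelude::*;

// Produce a placeholder column for a frame that is missing `name`, as the
// diagonal concat above does; the name is cloned because the Series owns it.
fn missing_column(name: &PlSmallStr, height: usize, dtype: &DataType) -> Series {
    Series::full_null(name.clone(), height, dtype)
}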
pub struct IdxHash { diff --git a/crates/polars-core/src/hashing/mod.rs b/crates/polars-core/src/hashing/mod.rs index 5e1b891a9702..8f966eb2f317 100644 --- a/crates/polars-core/src/hashing/mod.rs +++ b/crates/polars-core/src/hashing/mod.rs @@ -3,7 +3,6 @@ pub(crate) mod vector_hasher; use std::hash::{BuildHasher, BuildHasherDefault, Hash, Hasher}; -use ahash::RandomState; use hashbrown::hash_map::RawEntryMut; use hashbrown::HashMap; pub use identity::*; @@ -40,6 +39,7 @@ pub(crate) unsafe fn compare_df_rows(keys: &DataFrame, idx_a: usize, idx_b: usiz } /// Populate a multiple key hashmap with row indexes. +/// /// Instead of the keys (which could be very large), the row indexes are stored. /// To check if a row is equal the original DataFrame is also passed as ref. /// When a hash collision occurs the indexes are ptrs to the rows and the rows are compared diff --git a/crates/polars-core/src/hashing/vector_hasher.rs b/crates/polars-core/src/hashing/vector_hasher.rs index 5bdf8d126bd6..277c1c009ba0 100644 --- a/crates/polars-core/src/hashing/vector_hasher.rs +++ b/crates/polars-core/src/hashing/vector_hasher.rs @@ -17,15 +17,13 @@ const MULTIPLE: u64 = 6364136223846793005; pub trait VecHash { /// Compute the hash for all values in the array. - /// - /// This currently only works with the AHash RandomState hasher builder. - fn vec_hash(&self, _random_state: RandomState, _buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash(&self, _random_state: PlRandomState, _buf: &mut Vec) -> PolarsResult<()> { polars_bail!(un_impl = vec_hash); } fn vec_hash_combine( &self, - _random_state: RandomState, + _random_state: PlRandomState, _hashes: &mut [u64], ) -> PolarsResult<()> { polars_bail!(un_impl = vec_hash_combine); @@ -37,14 +35,14 @@ pub(crate) const fn folded_multiply(s: u64, by: u64) -> u64 { ((result & 0xffff_ffff_ffff_ffff) as u64) ^ ((result >> 64) as u64) } -pub(crate) fn get_null_hash_value(random_state: &RandomState) -> u64 { +pub(crate) fn get_null_hash_value(random_state: &PlRandomState) -> u64 { // we just start with a large prime number and hash that twice // to get a constant hash value for null/None let first = random_state.hash_one(3188347919usize); random_state.hash_one(first) } -fn insert_null_hash(chunks: &[ArrayRef], random_state: RandomState, buf: &mut Vec) { +fn insert_null_hash(chunks: &[ArrayRef], random_state: PlRandomState, buf: &mut Vec) { let null_h = get_null_hash_value(&random_state); let hashes = buf.as_mut_slice(); @@ -64,7 +62,7 @@ fn insert_null_hash(chunks: &[ArrayRef], random_state: RandomState, buf: &mut Ve }); } -fn numeric_vec_hash(ca: &ChunkedArray, random_state: RandomState, buf: &mut Vec) +fn numeric_vec_hash(ca: &ChunkedArray, random_state: PlRandomState, buf: &mut Vec) where T: PolarsNumericType, T::Native: TotalHash + ToTotalOrd, @@ -93,8 +91,11 @@ where insert_null_hash(&ca.chunks, random_state, buf) } -fn numeric_vec_hash_combine(ca: &ChunkedArray, random_state: RandomState, hashes: &mut [u64]) -where +fn numeric_vec_hash_combine( + ca: &ChunkedArray, + random_state: PlRandomState, + hashes: &mut [u64], +) where T: PolarsNumericType, T::Native: TotalHash + ToTotalOrd, ::TotalOrdItem: Hash, @@ -139,14 +140,18 @@ where macro_rules! 
vec_hash_numeric { ($ca:ident) => { impl VecHash for $ca { - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash( + &self, + random_state: PlRandomState, + buf: &mut Vec, + ) -> PolarsResult<()> { numeric_vec_hash(self, random_state, buf); Ok(()) } fn vec_hash_combine( &self, - random_state: RandomState, + random_state: PlRandomState, hashes: &mut [u64], ) -> PolarsResult<()> { numeric_vec_hash_combine(self, random_state, hashes); @@ -170,19 +175,23 @@ vec_hash_numeric!(Float32Chunked); vec_hash_numeric!(Int128Chunked); impl VecHash for StringChunked { - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash(&self, random_state: PlRandomState, buf: &mut Vec) -> PolarsResult<()> { self.as_binary().vec_hash(random_state, buf)?; Ok(()) } - fn vec_hash_combine(&self, random_state: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { + fn vec_hash_combine( + &self, + random_state: PlRandomState, + hashes: &mut [u64], + ) -> PolarsResult<()> { self.as_binary().vec_hash_combine(random_state, hashes)?; Ok(()) } } // used in polars-pipe -pub fn _hash_binary_array(arr: &BinaryArray, random_state: RandomState, buf: &mut Vec) { +pub fn _hash_binary_array(arr: &BinaryArray, random_state: PlRandomState, buf: &mut Vec) { let null_h = get_null_hash_value(&random_state); if arr.null_count() == 0 { // use the null_hash as seed to get a hash determined by `random_state` that is passed @@ -195,7 +204,7 @@ pub fn _hash_binary_array(arr: &BinaryArray, random_state: RandomState, buf } } -fn hash_binview_array(arr: &BinaryViewArray, random_state: RandomState, buf: &mut Vec) { +fn hash_binview_array(arr: &BinaryViewArray, random_state: PlRandomState, buf: &mut Vec) { let null_h = get_null_hash_value(&random_state); if arr.null_count() == 0 { // use the null_hash as seed to get a hash determined by `random_state` that is passed @@ -209,7 +218,7 @@ fn hash_binview_array(arr: &BinaryViewArray, random_state: RandomState, buf: &mu } impl VecHash for BinaryChunked { - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash(&self, random_state: PlRandomState, buf: &mut Vec) -> PolarsResult<()> { buf.clear(); buf.reserve(self.len()); self.downcast_iter() @@ -217,7 +226,11 @@ impl VecHash for BinaryChunked { Ok(()) } - fn vec_hash_combine(&self, random_state: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { + fn vec_hash_combine( + &self, + random_state: PlRandomState, + hashes: &mut [u64], + ) -> PolarsResult<()> { let null_h = get_null_hash_value(&random_state); let mut offset = 0; @@ -254,7 +267,7 @@ impl VecHash for BinaryChunked { } impl VecHash for BinaryOffsetChunked { - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash(&self, random_state: PlRandomState, buf: &mut Vec) -> PolarsResult<()> { buf.clear(); buf.reserve(self.len()); self.downcast_iter() @@ -262,7 +275,11 @@ impl VecHash for BinaryOffsetChunked { Ok(()) } - fn vec_hash_combine(&self, random_state: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { + fn vec_hash_combine( + &self, + random_state: PlRandomState, + hashes: &mut [u64], + ) -> PolarsResult<()> { let null_h = get_null_hash_value(&random_state); let mut offset = 0; @@ -299,14 +316,18 @@ impl VecHash for BinaryOffsetChunked { } impl VecHash for NullChunked { - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash(&self, random_state: PlRandomState, buf: &mut Vec) -> 
PolarsResult<()> { let null_h = get_null_hash_value(&random_state); buf.clear(); buf.resize(self.len(), null_h); Ok(()) } - fn vec_hash_combine(&self, random_state: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { + fn vec_hash_combine( + &self, + random_state: PlRandomState, + hashes: &mut [u64], + ) -> PolarsResult<()> { let null_h = get_null_hash_value(&random_state); hashes .iter_mut() @@ -315,7 +336,7 @@ impl VecHash for NullChunked { } } impl VecHash for BooleanChunked { - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash(&self, random_state: PlRandomState, buf: &mut Vec) -> PolarsResult<()> { buf.clear(); buf.reserve(self.len()); let true_h = random_state.hash_one(true); @@ -335,7 +356,11 @@ impl VecHash for BooleanChunked { Ok(()) } - fn vec_hash_combine(&self, random_state: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { + fn vec_hash_combine( + &self, + random_state: PlRandomState, + hashes: &mut [u64], + ) -> PolarsResult<()> { let true_h = random_state.hash_one(true); let false_h = random_state.hash_one(false); let null_h = get_null_hash_value(&random_state); @@ -382,7 +407,7 @@ impl VecHash for ObjectChunked where T: PolarsObject, { - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash(&self, random_state: PlRandomState, buf: &mut Vec) -> PolarsResult<()> { // Note that we don't use the no null branch! This can break in unexpected ways. // for instance with threading we split an array in n_threads, this may lead to // splits that have no nulls and splits that have nulls. Then one array is hashed with @@ -398,7 +423,11 @@ where Ok(()) } - fn vec_hash_combine(&self, random_state: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { + fn vec_hash_combine( + &self, + random_state: PlRandomState, + hashes: &mut [u64], + ) -> PolarsResult<()> { self.apply_to_slice( |opt_v, h| { let hashed = random_state.hash_one(opt_v); @@ -412,8 +441,8 @@ where pub fn _df_rows_to_hashes_threaded_vertical( keys: &[DataFrame], - hasher_builder: Option, -) -> PolarsResult<(Vec, RandomState)> { + hasher_builder: Option, +) -> PolarsResult<(Vec, PlRandomState)> { let hasher_builder = hasher_builder.unwrap_or_default(); let hashes = POOL.install(|| { @@ -422,7 +451,7 @@ pub fn _df_rows_to_hashes_threaded_vertical( let hb = hasher_builder.clone(); let mut hashes = vec![]; series_to_hashes(df.get_columns(), Some(hb), &mut hashes)?; - Ok(UInt64Chunked::from_vec("", hashes)) + Ok(UInt64Chunked::from_vec(PlSmallStr::EMPTY, hashes)) }) .collect::>>() })?; @@ -431,9 +460,9 @@ pub fn _df_rows_to_hashes_threaded_vertical( pub(crate) fn series_to_hashes( keys: &[Series], - build_hasher: Option, + build_hasher: Option, hashes: &mut Vec, -) -> PolarsResult { +) -> PolarsResult { let build_hasher = build_hasher.unwrap_or_default(); let mut iter = keys.iter(); diff --git a/crates/polars-core/src/named_from.rs b/crates/polars-core/src/named_from.rs index 8bcc17cef853..4d5714e4e517 100644 --- a/crates/polars-core/src/named_from.rs +++ b/crates/polars-core/src/named_from.rs @@ -14,18 +14,18 @@ use crate::prelude::*; pub trait NamedFrom { /// Initialize by name and values. - fn new(name: &str, _: T) -> Self; + fn new(name: PlSmallStr, _: T) -> Self; } pub trait NamedFromOwned { /// Initialize by name and values. - fn from_vec(name: &str, _: T) -> Self; + fn from_vec(name: PlSmallStr, _: T) -> Self; } macro_rules! 
impl_named_from_owned { ($type:ty, $polars_type:ident) => { impl NamedFromOwned<$type> for Series { - fn from_vec(name: &str, v: $type) -> Self { + fn from_vec(name: PlSmallStr, v: $type) -> Self { ChunkedArray::<$polars_type>::from_vec(name, v).into_series() } } @@ -52,12 +52,12 @@ impl_named_from_owned!(Vec, Float64Type); macro_rules! impl_named_from { ($type:ty, $polars_type:ident, $method:ident) => { impl> NamedFrom for Series { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { ChunkedArray::<$polars_type>::$method(name, v.as_ref()).into_series() } } impl> NamedFrom for ChunkedArray<$polars_type> { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { ChunkedArray::<$polars_type>::$method(name, v.as_ref()) } } @@ -106,14 +106,14 @@ impl_named_from!([Option], Float64Type, from_slice_options); macro_rules! impl_named_from_range { ($range:ty, $polars_type:ident) => { impl NamedFrom<$range, $polars_type> for ChunkedArray<$polars_type> { - fn new(name: &str, range: $range) -> Self { + fn new(name: PlSmallStr, range: $range) -> Self { let values = range.collect::>(); ChunkedArray::<$polars_type>::from_vec(name, values) } } impl NamedFrom<$range, $polars_type> for Series { - fn new(name: &str, range: $range) -> Self { + fn new(name: PlSmallStr, range: $range) -> Self { ChunkedArray::new(name, range).into_series() } } @@ -125,7 +125,7 @@ impl_named_from_range!(std::ops::Range, UInt64Type); impl_named_from_range!(std::ops::Range, UInt32Type); impl> NamedFrom for Series { - fn new(name: &str, s: T) -> Self { + fn new(name: PlSmallStr, s: T) -> Self { let series_slice = s.as_ref(); let list_cap = series_slice.len(); @@ -155,7 +155,7 @@ impl> NamedFrom for Series { } impl]>> NamedFrom]> for Series { - fn new(name: &str, s: T) -> Self { + fn new(name: PlSmallStr, s: T) -> Self { let series_slice = s.as_ref(); let values_cap = series_slice.iter().fold(0, |acc, opt_s| { acc + opt_s.as_ref().map(|s| s.len()).unwrap_or(0) @@ -173,13 +173,13 @@ impl]>> NamedFrom]> for Series { } } impl<'a, T: AsRef<[&'a str]>> NamedFrom for Series { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { StringChunked::from_slice(name, v.as_ref()).into_series() } } impl NamedFrom<&Series, str> for Series { - fn new(name: &str, s: &Series) -> Self { + fn new(name: PlSmallStr, s: &Series) -> Self { let mut s = s.clone(); s.rename(name); s @@ -187,44 +187,44 @@ impl NamedFrom<&Series, str> for Series { } impl<'a, T: AsRef<[&'a str]>> NamedFrom for StringChunked { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { StringChunked::from_slice(name, v.as_ref()) } } impl<'a, T: AsRef<[Option<&'a str>]>> NamedFrom]> for Series { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { StringChunked::from_slice_options(name, v.as_ref()).into_series() } } impl<'a, T: AsRef<[Option<&'a str>]>> NamedFrom]> for StringChunked { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { StringChunked::from_slice_options(name, v.as_ref()) } } impl<'a, T: AsRef<[Cow<'a, str>]>> NamedFrom]> for Series { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { StringChunked::from_iter_values(name, v.as_ref().iter().map(|value| value.as_ref())) .into_series() } } impl<'a, T: AsRef<[Cow<'a, str>]>> NamedFrom]> for StringChunked { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { StringChunked::from_iter_values(name, 
v.as_ref().iter().map(|value| value.as_ref())) } } impl<'a, T: AsRef<[Option>]>> NamedFrom>]> for Series { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { StringChunked::new(name, v).into_series() } } impl<'a, T: AsRef<[Option>]>> NamedFrom>]> for StringChunked { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { StringChunked::from_iter_options( name, v.as_ref() @@ -235,44 +235,44 @@ impl<'a, T: AsRef<[Option>]>> NamedFrom>]> } impl<'a, T: AsRef<[&'a [u8]]>> NamedFrom for Series { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { BinaryChunked::from_slice(name, v.as_ref()).into_series() } } impl<'a, T: AsRef<[&'a [u8]]>> NamedFrom for BinaryChunked { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { BinaryChunked::from_slice(name, v.as_ref()) } } impl<'a, T: AsRef<[Option<&'a [u8]>]>> NamedFrom]> for Series { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { BinaryChunked::from_slice_options(name, v.as_ref()).into_series() } } impl<'a, T: AsRef<[Option<&'a [u8]>]>> NamedFrom]> for BinaryChunked { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { BinaryChunked::from_slice_options(name, v.as_ref()) } } impl<'a, T: AsRef<[Cow<'a, [u8]>]>> NamedFrom]> for Series { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { BinaryChunked::from_iter_values(name, v.as_ref().iter().map(|value| value.as_ref())) .into_series() } } impl<'a, T: AsRef<[Cow<'a, [u8]>]>> NamedFrom]> for BinaryChunked { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { BinaryChunked::from_iter_values(name, v.as_ref().iter().map(|value| value.as_ref())) } } impl<'a, T: AsRef<[Option>]>> NamedFrom>]> for Series { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { BinaryChunked::new(name, v).into_series() } } @@ -280,7 +280,7 @@ impl<'a, T: AsRef<[Option>]>> NamedFrom>] impl<'a, T: AsRef<[Option>]>> NamedFrom>]> for BinaryChunked { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { BinaryChunked::from_iter_options( name, v.as_ref() @@ -292,35 +292,35 @@ impl<'a, T: AsRef<[Option>]>> NamedFrom>] #[cfg(feature = "dtype-date")] impl> NamedFrom for DateChunked { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { DateChunked::from_naive_date(name, v.as_ref().iter().copied()) } } #[cfg(feature = "dtype-date")] impl> NamedFrom for Series { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { DateChunked::new(name, v).into_series() } } #[cfg(feature = "dtype-date")] impl]>> NamedFrom]> for DateChunked { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { DateChunked::from_naive_date_options(name, v.as_ref().iter().copied()) } } #[cfg(feature = "dtype-date")] impl]>> NamedFrom]> for Series { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { DateChunked::new(name, v).into_series() } } #[cfg(feature = "dtype-datetime")] impl> NamedFrom for DatetimeChunked { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { DatetimeChunked::from_naive_datetime( name, v.as_ref().iter().copied(), @@ -331,14 +331,14 @@ impl> NamedFrom for DatetimeChunke #[cfg(feature = "dtype-datetime")] impl> NamedFrom for Series { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { DatetimeChunked::new(name, 
v).into_series() } } #[cfg(feature = "dtype-datetime")] impl]>> NamedFrom]> for DatetimeChunked { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { DatetimeChunked::from_naive_datetime_options( name, v.as_ref().iter().copied(), @@ -349,21 +349,21 @@ impl]>> NamedFrom]> fo #[cfg(feature = "dtype-datetime")] impl]>> NamedFrom]> for Series { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { DatetimeChunked::new(name, v).into_series() } } #[cfg(feature = "dtype-duration")] impl> NamedFrom for DurationChunked { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { DurationChunked::from_duration(name, v.as_ref().iter().copied(), TimeUnit::Nanoseconds) } } #[cfg(feature = "dtype-duration")] impl> NamedFrom for Series { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { DurationChunked::new(name, v).into_series() } } @@ -372,7 +372,7 @@ impl> NamedFrom for Series { impl]>> NamedFrom]> for DurationChunked { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { DurationChunked::from_duration_options( name, v.as_ref().iter().copied(), @@ -383,49 +383,49 @@ impl]>> NamedFrom]> #[cfg(feature = "dtype-duration")] impl]>> NamedFrom]> for Series { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { DurationChunked::new(name, v).into_series() } } #[cfg(feature = "dtype-time")] impl> NamedFrom for TimeChunked { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { TimeChunked::from_naive_time(name, v.as_ref().iter().copied()) } } #[cfg(feature = "dtype-time")] impl> NamedFrom for Series { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { TimeChunked::new(name, v).into_series() } } #[cfg(feature = "dtype-time")] impl]>> NamedFrom]> for TimeChunked { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { TimeChunked::from_naive_time_options(name, v.as_ref().iter().copied()) } } #[cfg(feature = "dtype-time")] impl]>> NamedFrom]> for Series { - fn new(name: &str, v: T) -> Self { + fn new(name: PlSmallStr, v: T) -> Self { TimeChunked::new(name, v).into_series() } } #[cfg(feature = "object")] impl NamedFrom<&[T], &[T]> for ObjectChunked { - fn new(name: &str, v: &[T]) -> Self { + fn new(name: PlSmallStr, v: &[T]) -> Self { ObjectChunked::from_slice(name, v) } } #[cfg(feature = "object")] impl]>> NamedFrom]> for ObjectChunked { - fn new(name: &str, v: S) -> Self { + fn new(name: PlSmallStr, v: S) -> Self { ObjectChunked::from_slice_options(name, v.as_ref()) } } @@ -433,14 +433,14 @@ impl]>> NamedFrom]> for Objec impl ChunkedArray { /// Specialization that prevents an allocation /// prefer this over ChunkedArray::new when you have a `Vec` and no null values. 
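A minimal usage sketch, not part of the patch: with `NamedFrom` (and `new_vec` just below) taking an owned `PlSmallStr`, string literals convert explicitly at the call site, as the updated test at the end of this file shows.

use polars_core::prelude::*;

fn main() {
    // Literals convert via Into<PlSmallStr>; name() now returns &PlSmallStr.
    let s = Series::new("values".into(), &[1i32, 2, 3]);
    assert_eq!(s.name().as_str(), "values");
}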
- pub fn new_vec(name: &str, v: Vec) -> Self { + pub fn new_vec(name: PlSmallStr, v: Vec) -> Self { ChunkedArray::from_vec(name, v) } } /// For any [`ChunkedArray`] and [`Series`] impl NamedFrom for Series { - fn new(name: &str, t: T) -> Self { + fn new(name: PlSmallStr, t: T) -> Self { let mut s = t.into_series(); s.rename(name); s @@ -474,9 +474,9 @@ mod test { #[test] fn build_series_from_empty_series_vec() { - let empty_series = Series::new("test", Vec::::new()); + let empty_series = Series::new("test".into(), Vec::::new()); assert_eq!(empty_series.len(), 0); assert_eq!(*empty_series.dtype(), DataType::Null); - assert_eq!(empty_series.name(), "test"); + assert_eq!(empty_series.name().as_str(), "test"); } } diff --git a/crates/polars-core/src/prelude.rs b/crates/polars-core/src/prelude.rs index a2a865f7e63c..996c9b83c5c5 100644 --- a/crates/polars-core/src/prelude.rs +++ b/crates/polars-core/src/prelude.rs @@ -7,6 +7,7 @@ pub use arrow::datatypes::{ArrowSchema, Field as ArrowField}; pub use arrow::legacy::prelude::*; pub(crate) use arrow::trusted_len::TrustedLen; pub use polars_utils::index::{ChunkId, IdxSize, NullableChunkId, NullableIdxSize}; +pub use polars_utils::pl_str::PlSmallStr; pub(crate) use polars_utils::total_ord::{TotalEq, TotalOrd}; pub use crate::chunked_array::arithmetic::ArithmeticChunked; @@ -39,7 +40,7 @@ pub use crate::datatypes::{ArrayCollectIterExt, *}; pub use crate::error::{ polars_bail, polars_ensure, polars_err, polars_warn, PolarsError, PolarsResult, }; -pub use crate::frame::explode::UnpivotArgs; +pub use crate::frame::explode::UnpivotArgsIR; #[cfg(feature = "algorithm_group_by")] pub(crate) use crate::frame::group_by::aggregations::*; #[cfg(feature = "algorithm_group_by")] diff --git a/crates/polars-core/src/scalar/mod.rs b/crates/polars-core/src/scalar/mod.rs index 07ed78b0863f..ac7a946ebebc 100644 --- a/crates/polars-core/src/scalar/mod.rs +++ b/crates/polars-core/src/scalar/mod.rs @@ -1,19 +1,33 @@ pub mod reduce; +use polars_utils::pl_str::PlSmallStr; + use crate::datatypes::{AnyValue, DataType}; use crate::prelude::Series; -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Scalar { dtype: DataType, value: AnyValue<'static>, } impl Scalar { + #[inline(always)] pub fn new(dtype: DataType, value: AnyValue<'static>) -> Self { Self { dtype, value } } + #[inline(always)] + pub fn is_null(&self) -> bool { + self.value.is_null() + } + + #[inline(always)] + pub fn is_nan(&self) -> bool { + self.value.is_nan() + } + + #[inline(always)] pub fn value(&self) -> &AnyValue<'static> { &self.value } @@ -24,14 +38,16 @@ impl Scalar { .unwrap_or_else(|| self.value.clone()) } - pub fn into_series(self, name: &str) -> Series { + pub fn into_series(self, name: PlSmallStr) -> Series { Series::from_any_values_and_dtype(name, &[self.as_any_value()], &self.dtype, true).unwrap() } + #[inline(always)] pub fn dtype(&self) -> &DataType { &self.dtype } + #[inline(always)] pub fn update(&mut self, value: AnyValue<'static>) { self.value = value; } diff --git a/crates/polars-core/src/schema.rs b/crates/polars-core/src/schema.rs index dfb581135a0b..d100cf91172f 100644 --- a/crates/polars-core/src/schema.rs +++ b/crates/polars-core/src/schema.rs @@ -1,422 +1,85 @@ -use std::fmt::{Debug, Formatter}; -use std::hash::{Hash, Hasher}; +use std::fmt::Debug; -use arrow::datatypes::ArrowSchemaRef; -use indexmap::map::MutableKeys; -use indexmap::IndexMap; -#[cfg(feature = "serde-lazy")] -use serde::{Deserialize, Serialize}; -use smartstring::alias::String as SmartString; +use 
polars_utils::pl_str::PlSmallStr; use crate::prelude::*; use crate::utils::try_get_supertype; -/// A map from field/column name ([`String`](smartstring::alias::String)) to the type of that field/column ([`DataType`]) -#[derive(Eq, Clone, Default)] -#[cfg_attr(feature = "serde-lazy", derive(Serialize, Deserialize))] -pub struct Schema { - inner: PlIndexMap, -} - -impl Hash for Schema { - fn hash(&self, state: &mut H) { - self.inner.iter().for_each(|v| v.hash(state)) - } -} - -// Schemas will only compare equal if they have the same fields in the same order. We can't use `self.inner == -// other.inner` because [`IndexMap`] ignores order when checking equality, but we don't want to ignore it. -impl PartialEq for Schema { - fn eq(&self, other: &Self) -> bool { - self.len() == other.len() && self.iter().zip(other.iter()).all(|(a, b)| a == b) - } -} - -impl Debug for Schema { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - writeln!(f, "Schema:")?; - for (name, dtype) in self.inner.iter() { - writeln!(f, "name: {name}, data type: {dtype:?}")?; - } - Ok(()) - } -} - -impl From<&[Series]> for Schema { - fn from(value: &[Series]) -> Self { - value.iter().map(|s| s.field().into_owned()).collect() - } -} - -impl FromIterator for Schema -where - F: Into, -{ - fn from_iter>(iter: T) -> Self { - let iter = iter.into_iter(); - let mut map: PlIndexMap<_, _> = - IndexMap::with_capacity_and_hasher(iter.size_hint().0, ahash::RandomState::default()); - for fld in iter { - let fld = fld.into(); - map.insert(fld.name, fld.dtype); - } - Self { inner: map } - } -} - -impl Schema { - /// Create a new, empty schema - pub fn new() -> Self { - Self::with_capacity(0) - } - - /// Create a new, empty schema with capacity - /// - /// If you know the number of fields you have ahead of time, using this is more efficient than using - /// [`new`][Self::new]. Also consider using [`Schema::from_iter`] if you have the collection of fields available - /// ahead of time. - pub fn with_capacity(capacity: usize) -> Self { - let map: PlIndexMap<_, _> = - IndexMap::with_capacity_and_hasher(capacity, ahash::RandomState::default()); - Self { inner: map } - } - - /// Reserve `additional` memory spaces in the schema. - pub fn reserve(&mut self, additional: usize) { - self.inner.reserve(additional); - } - - /// The number of fields in the schema - #[inline] - pub fn len(&self) -> usize { - self.inner.len() - } - - #[inline] - pub fn is_empty(&self) -> bool { - self.inner.is_empty() - } - - /// Rename field `old` to `new`, and return the (owned) old name - /// - /// If `old` is not present in the schema, the schema is not modified and `None` is returned. Otherwise the schema - /// is updated and `Some(old_name)` is returned. - pub fn rename(&mut self, old: &str, new: SmartString) -> Option { - // Remove `old`, get the corresponding index and dtype, and move the last item in the map to that position - let (old_index, old_name, dtype) = self.inner.swap_remove_full(old)?; - // Insert the same dtype under the new name at the end of the map and store that index - let (new_index, _) = self.inner.insert_full(new, dtype); - // Swap the two indices to move the originally last element back to the end and to move the new element back to - // its original position - self.inner.swap_indices(old_index, new_index); - - Some(old_name) - } - - /// Create a new schema from this one, inserting a field with `name` and `dtype` at the given `index` - /// - /// If a field named `name` already exists, it is updated with the new dtype. 
Regardless, the field named `name` is - /// always moved to the given index. Valid indices range from `0` (front of the schema) to `self.len()` (after the - /// end of the schema). - /// - /// For a mutating version that doesn't clone, see [`insert_at_index`][Self::insert_at_index]. - /// - /// Runtime: **O(m * n)** where `m` is the (average) length of the field names and `n` is the number of fields in - /// the schema. This method clones every field in the schema. - /// - /// Returns: `Ok(new_schema)` if `index <= self.len()`, else `Err(PolarsError)` - pub fn new_inserting_at_index( - &self, - index: usize, - name: SmartString, - dtype: DataType, - ) -> PolarsResult { - polars_ensure!( - index <= self.len(), - OutOfBounds: - "index {} is out of bounds for schema with length {} (the max index allowed is self.len())", - index, - self.len() - ); +pub type SchemaRef = Arc; +pub type Schema = polars_schema::Schema; - let mut new = Self::default(); - let mut iter = self.inner.iter().filter_map(|(fld_name, dtype)| { - (fld_name != &name).then_some((fld_name.clone(), dtype.clone())) - }); - new.inner.extend(iter.by_ref().take(index)); - new.inner.insert(name.clone(), dtype); - new.inner.extend(iter); - Ok(new) - } +pub trait SchemaExt { + fn from_arrow_schema(value: &ArrowSchema) -> Self; - /// Insert a field with `name` and `dtype` at the given `index` into this schema - /// - /// If a field named `name` already exists, it is updated with the new dtype. Regardless, the field named `name` is - /// always moved to the given index. Valid indices range from `0` (front of the schema) to `self.len()` (after the - /// end of the schema). - /// - /// For a non-mutating version that clones the schema, see [`new_inserting_at_index`][Self::new_inserting_at_index]. - /// - /// Runtime: **O(n)** where `n` is the number of fields in the schema. - /// - /// Returns: - /// - If index is out of bounds, `Err(PolarsError)` - /// - Else if `name` was already in the schema, `Ok(Some(old_dtype))` - /// - Else `Ok(None)` - pub fn insert_at_index( - &mut self, - mut index: usize, - name: SmartString, - dtype: DataType, - ) -> PolarsResult> { - polars_ensure!( - index <= self.len(), - OutOfBounds: - "index {} is out of bounds for schema with length {} (the max index allowed is self.len())", - index, - self.len() - ); + fn get_field(&self, name: &str) -> Option; - let (old_index, old_dtype) = self.inner.insert_full(name, dtype); + fn try_get_field(&self, name: &str) -> PolarsResult; - // If we're moving an existing field, one-past-the-end will actually be out of bounds. Also, self.len() won't - // have changed after inserting, so `index == self.len()` is the same as it was before inserting. 
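A hedged sketch, not part of the patch: `insert_at_index` itself is not shown in this hunk, so the assumption here is that it carries over to the generic schema type with the contract described above (an existing name keeps a single entry and moves to the requested position).

use polars_core::prelude::*;

fn main() -> PolarsResult<()> {
    let mut schema = Schema::from_iter(vec![
        (PlSmallStr::from("a"), DataType::Int64),
        (PlSmallStr::from("b"), DataType::String),
    ]);
    // Insert a new column at position 1; later columns shift to make room.
    schema.insert_at_index(1, "c".into(), DataType::Boolean)?;
    assert_eq!(schema.len(), 3);
    Ok(())
}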
- if old_dtype.is_some() && index == self.len() { - index -= 1; - } - self.inner.move_index(old_index, index); - Ok(old_dtype) - } + fn to_arrow(&self, compat_level: CompatLevel) -> ArrowSchema; - /// Get a reference to the dtype of the field named `name`, or `None` if the field doesn't exist - pub fn get(&self, name: &str) -> Option<&DataType> { - self.inner.get(name) - } + fn iter_fields(&self) -> impl ExactSizeIterator + '_; - /// Get a reference to the dtype of the field named `name`, or `Err(PolarsErr)` if the field doesn't exist - pub fn try_get(&self, name: &str) -> PolarsResult<&DataType> { - self.get(name) - .ok_or_else(|| polars_err!(SchemaFieldNotFound: "{}", name)) - } - - /// Get a mutable reference to the dtype of the field named `name`, or `Err(PolarsErr)` if the field doesn't exist - pub fn try_get_mut(&mut self, name: &str) -> PolarsResult<&mut DataType> { - self.inner - .get_mut(name) - .ok_or_else(|| polars_err!(SchemaFieldNotFound: "{}", name)) - } - - /// Return all data about the field named `name`: its index in the schema, its name, and its dtype - /// - /// Returns `Some((index, &name, &dtype))` if the field exists, `None` if it doesn't. - pub fn get_full(&self, name: &str) -> Option<(usize, &SmartString, &DataType)> { - self.inner.get_full(name) - } + fn to_supertype(&mut self, other: &Schema) -> PolarsResult; +} - /// Return all data about the field named `name`: its index in the schema, its name, and its dtype - /// - /// Returns `Ok((index, &name, &dtype))` if the field exists, `Err(PolarsErr)` if it doesn't. - pub fn try_get_full(&self, name: &str) -> PolarsResult<(usize, &SmartString, &DataType)> { - self.inner - .get_full(name) - .ok_or_else(|| polars_err!(SchemaFieldNotFound: "{}", name)) +impl SchemaExt for Schema { + fn from_arrow_schema(value: &ArrowSchema) -> Self { + value + .iter_values() + .map(|x| (x.name.clone(), DataType::from_arrow(&x.dtype, true))) + .collect() } - /// Look up the name in the schema and return an owned [`Field`] by cloning the data + /// Look up the name in the schema and return an owned [`Field`] by cloning the data. /// /// Returns `None` if the field does not exist. /// /// This method constructs the `Field` by cloning the name and dtype. For a version that returns references, see /// [`get`][Self::get] or [`get_full`][Self::get_full]. - pub fn get_field(&self, name: &str) -> Option { - self.inner - .get(name) - .map(|dtype| Field::new(name, dtype.clone())) + fn get_field(&self, name: &str) -> Option { + self.get_full(name) + .map(|(_, name, dtype)| Field::new(name.clone(), dtype.clone())) } - /// Look up the name in the schema and return an owned [`Field`] by cloning the data + /// Look up the name in the schema and return an owned [`Field`] by cloning the data. /// /// Returns `Err(PolarsErr)` if the field does not exist. /// /// This method constructs the `Field` by cloning the name and dtype. For a version that returns references, see /// [`get`][Self::get] or [`get_full`][Self::get_full]. - pub fn try_get_field(&self, name: &str) -> PolarsResult { - self.inner - .get(name) + fn try_get_field(&self, name: &str) -> PolarsResult { + self.get_full(name) .ok_or_else(|| polars_err!(SchemaFieldNotFound: "{}", name)) - .map(|dtype| Field::new(name, dtype.clone())) - } - - /// Get references to the name and dtype of the field at `index` - /// - /// If `index` is inbounds, returns `Some((&name, &dtype))`, else `None`. See - /// [`get_at_index_mut`][Self::get_at_index_mut] for a mutable version. 
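A minimal sketch (not from the patch) of the cloning lookup described above, assuming `SchemaExt` is reachable through the prelude:

use polars_core::prelude::*;

fn main() {
    let schema = Schema::from_iter(vec![(PlSmallStr::from("x"), DataType::Float64)]);
    // get_field clones the stored name and dtype into an owned Field.
    let field = schema.get_field("x").expect("column exists");
    assert_eq!(field.name().as_str(), "x");
    assert_eq!(field.dtype(), &DataType::Float64);
}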
- pub fn get_at_index(&self, index: usize) -> Option<(&SmartString, &DataType)> { - self.inner.get_index(index) - } - - pub fn try_get_at_index(&self, index: usize) -> PolarsResult<(&SmartString, &DataType)> { - self.inner.get_index(index).ok_or_else(|| polars_err!(ComputeError: "index {index} out of bounds with 'schema' of len: {}", self.len())) - } - - /// Get mutable references to the name and dtype of the field at `index` - /// - /// If `index` is inbounds, returns `Some((&mut name, &mut dtype))`, else `None`. See - /// [`get_at_index`][Self::get_at_index] for an immutable version. - pub fn get_at_index_mut(&mut self, index: usize) -> Option<(&mut SmartString, &mut DataType)> { - self.inner.get_index_mut2(index) - } - - /// Swap-remove a field by name and, if the field existed, return its dtype - /// - /// If the field does not exist, the schema is not modified and `None` is returned. - /// - /// This method does a `swap_remove`, which is O(1) but **changes the order of the schema**: the field named `name` - /// is replaced by the last field, which takes its position. For a slower, but order-preserving, method, use - /// [`shift_remove`][Self::shift_remove]. - pub fn remove(&mut self, name: &str) -> Option { - self.inner.swap_remove(name) - } - - /// Remove a field by name, preserving order, and, if the field existed, return its dtype - /// - /// If the field does not exist, the schema is not modified and `None` is returned. - /// - /// This method does a `shift_remove`, which preserves the order of the fields in the schema but **is O(n)**. For a - /// faster, but not order-preserving, method, use [`remove`][Self::remove]. - pub fn shift_remove(&mut self, name: &str) -> Option { - self.inner.shift_remove(name) - } - - /// Remove a field by name, preserving order, and, if the field existed, return its dtype - /// - /// If the field does not exist, the schema is not modified and `None` is returned. - /// - /// This method does a `shift_remove`, which preserves the order of the fields in the schema but **is O(n)**. For a - /// faster, but not order-preserving, method, use [`remove`][Self::remove]. - pub fn shift_remove_index(&mut self, index: usize) -> Option<(SmartString, DataType)> { - self.inner.shift_remove_index(index) - } - - /// Whether the schema contains a field named `name` - pub fn contains(&self, name: &str) -> bool { - self.get(name).is_some() - } - - /// Change the field named `name` to the given `dtype` and return the previous dtype - /// - /// If `name` doesn't already exist in the schema, the schema is not modified and `None` is returned. Otherwise - /// returns `Some(old_dtype)`. - /// - /// This method only ever modifies an existing field and never adds a new field to the schema. To add a new field, - /// use [`with_column`][Self::with_column] or [`insert_at_index`][Self::insert_at_index]. - pub fn set_dtype(&mut self, name: &str, dtype: DataType) -> Option { - let old_dtype = self.inner.get_mut(name)?; - Some(std::mem::replace(old_dtype, dtype)) - } - - /// Change the field at the given index to the given `dtype` and return the previous dtype - /// - /// If the index is out of bounds, the schema is not modified and `None` is returned. Otherwise returns - /// `Some(old_dtype)`. - /// - /// This method only ever modifies an existing index and never adds a new field to the schema. To add a new field, - /// use [`with_column`][Self::with_column] or [`insert_at_index`][Self::insert_at_index]. 
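A hedged sketch, not part of the patch, of the append-or-update behaviour the removed docs describe; `with_column` and `get` are not shown in this hunk, so their presence on the generic schema type is an assumption here.

use polars_core::prelude::*;

fn main() {
    let mut schema = Schema::default();
    // Appends when the name is new, updates the dtype in place when it already exists.
    schema.with_column("id".into(), DataType::UInt32);
    schema.with_column("id".into(), DataType::UInt64);
    assert_eq!(schema.get("id"), Some(&DataType::UInt64));
}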
- pub fn set_dtype_at_index(&mut self, index: usize, dtype: DataType) -> Option { - let (_, old_dtype) = self.inner.get_index_mut(index)?; - Some(std::mem::replace(old_dtype, dtype)) - } - - /// Insert a new column in the [`Schema`] - /// - /// If an equivalent name already exists in the schema: the name remains and - /// retains in its place in the order, its corresponding value is updated - /// with [`DataType`] and the older dtype is returned inside `Some(_)`. - /// - /// If no equivalent key existed in the map: the new name-dtype pair is - /// inserted, last in order, and `None` is returned. - /// - /// To enforce the index of the resulting field, use [`insert_at_index`][Self::insert_at_index]. - /// - /// Computes in **O(1)** time (amortized average). - pub fn with_column(&mut self, name: SmartString, dtype: DataType) -> Option { - self.inner.insert(name, dtype) - } - - /// Merge `other` into `self` - /// - /// Merging logic: - /// - Fields that occur in `self` but not `other` are unmodified - /// - Fields that occur in `other` but not `self` are appended, in order, to the end of `self` - /// - Fields that occur in both `self` and `other` are updated with the dtype from `other`, but keep their original - /// index - pub fn merge(&mut self, other: Self) { - self.inner.extend(other.inner) - } - - /// Merge borrowed `other` into `self` - /// - /// Merging logic: - /// - Fields that occur in `self` but not `other` are unmodified - /// - Fields that occur in `other` but not `self` are appended, in order, to the end of `self` - /// - Fields that occur in both `self` and `other` are updated with the dtype from `other`, but keep their original - /// index - pub fn merge_from_ref(&mut self, other: &Self) { - self.inner.extend( - other - .iter() - .map(|(column, datatype)| (column.clone(), datatype.clone())), - ) - } - - /// Convert self to `ArrowSchema` by cloning the fields - pub fn to_arrow(&self, compat_level: CompatLevel) -> ArrowSchema { - let fields: Vec<_> = self - .inner - .iter() - .map(|(name, dtype)| dtype.to_arrow_field(name.as_str(), compat_level)) - .collect(); - ArrowSchema::from(fields) + .map(|(_, name, dtype)| Field::new(name.clone(), dtype.clone())) + } + + /// Convert self to `ArrowSchema` by cloning the fields. + fn to_arrow(&self, compat_level: CompatLevel) -> ArrowSchema { + self.iter() + .map(|(name, dtype)| { + ( + name.clone(), + dtype.to_arrow_field(name.clone(), compat_level), + ) + }) + .collect() } - /// Iterates the [`Field`]s in this schema, constructing them anew by cloning each `(&name, &dtype)` pair + /// Iterates the [`Field`]s in this schema, constructing them anew by cloning each `(&name, &dtype)` pair. /// /// Note that this clones each name and dtype in order to form an owned [`Field`]. For a clone-free version, use /// [`iter`][Self::iter], which returns `(&name, &dtype)`. 
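A minimal sketch (not from the patch) of the owning iteration documented above; the helper name is invented and `SchemaExt` is assumed to be in scope via the prelude.

use polars_core::prelude::*;

// Collect owned column names; iter_fields clones each (name, dtype) pair into a Field.
fn column_names(schema: &Schema) -> Vec<PlSmallStr> {
    schema.iter_fields().map(|f| f.name().clone()).collect()
}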
- pub fn iter_fields(&self) -> impl ExactSizeIterator + '_ { - self.inner - .iter() - .map(|(name, dtype)| Field::new(name, dtype.clone())) - } - - /// Iterates over references to the dtypes in this schema - pub fn iter_dtypes(&self) -> impl '_ + ExactSizeIterator { - self.inner.iter().map(|(_name, dtype)| dtype) - } - - /// Iterates over mut references to the dtypes in this schema - pub fn iter_dtypes_mut(&mut self) -> impl '_ + ExactSizeIterator { - self.inner.iter_mut().map(|(_name, dtype)| dtype) - } - - /// Iterates over references to the names in this schema - pub fn iter_names(&self) -> impl '_ + ExactSizeIterator { - self.inner.iter().map(|(name, _dtype)| name) - } - - /// Iterates over the `(&name, &dtype)` pairs in this schema - /// - /// For an owned version, use [`iter_fields`][Self::iter_fields], which clones the data to iterate owned `Field`s - pub fn iter(&self) -> impl Iterator + '_ { - self.inner.iter() + fn iter_fields(&self) -> impl ExactSizeIterator + '_ { + self.iter() + .map(|(name, dtype)| Field::new(name.clone(), dtype.clone())) } /// Take another [`Schema`] and try to find the supertypes between them. - pub fn to_supertype(&mut self, other: &Schema) -> PolarsResult { + fn to_supertype(&mut self, other: &Schema) -> PolarsResult { polars_ensure!(self.len() == other.len(), ComputeError: "schema lengths differ"); let mut changed = false; - for ((k, dt), (other_k, other_dt)) in self.inner.iter_mut().zip(other.iter()) { + for ((k, dt), (other_k, other_dt)) in self.iter_mut().zip(other.iter()) { polars_ensure!(k == other_k, ComputeError: "schema names differ: got {}, expected {}", k, other_k); let st = try_get_supertype(dt, other_dt)?; @@ -427,113 +90,46 @@ impl Schema { } } -pub type SchemaRef = Arc; - -impl IntoIterator for Schema { - type Item = (SmartString, DataType); - type IntoIter = as IntoIterator>::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.inner.into_iter() - } -} - -/// This trait exists to be unify the API of polars Schema and arrows Schema -pub trait IndexOfSchema: Debug { - /// Get the index of a column by name. - fn index_of(&self, name: &str) -> Option; - - /// Get a vector of all column names. 
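`to_supertype` walks both schemas in lockstep and widens each dtype pair via `try_get_supertype`. A minimal sketch of the widening step, assuming `try_get_supertype` is reachable at `polars_core::utils` as at the call site above:

```rust
use polars_core::prelude::*;
use polars_core::utils::try_get_supertype;

fn main() -> PolarsResult<()> {
    // Zipping an Int32 field with a Float64 field widens the schema entry to Float64.
    let st = try_get_supertype(&DataType::Int32, &DataType::Float64)?;
    assert_eq!(st, DataType::Float64);
    Ok(())
}
```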
- fn get_names(&self) -> Vec<&str>; - - fn try_index_of(&self, name: &str) -> PolarsResult { - self.index_of(name).ok_or_else(|| { - polars_err!( - ColumnNotFound: - "unable to find column {:?}; valid columns: {:?}", name, self.get_names(), - ) - }) - } -} - -impl IndexOfSchema for Schema { - fn index_of(&self, name: &str) -> Option { - self.inner.get_index_of(name) - } - - fn get_names(&self) -> Vec<&str> { - self.iter_names().map(|name| name.as_str()).collect() - } -} - -impl IndexOfSchema for ArrowSchema { - fn index_of(&self, name: &str) -> Option { - self.fields.iter().position(|f| f.name == name) - } - - fn get_names(&self) -> Vec<&str> { - self.fields.iter().map(|f| f.name.as_str()).collect() - } -} - pub trait SchemaNamesAndDtypes { const IS_ARROW: bool; - type DataType: Debug + PartialEq; - - /// Get a vector of (name, dtype) pairs - fn get_names_and_dtypes(&'_ self) -> Vec<(&'_ str, Self::DataType)>; -} + type DataType: Debug + Clone + Default + PartialEq; -impl SchemaNamesAndDtypes for Schema { - const IS_ARROW: bool = false; - type DataType = DataType; - - fn get_names_and_dtypes(&'_ self) -> Vec<(&'_ str, Self::DataType)> { - self.inner - .iter() - .map(|(name, dtype)| (name.as_str(), dtype.clone())) - .collect() - } + fn iter_names_and_dtypes( + &self, + ) -> impl ExactSizeIterator; } impl SchemaNamesAndDtypes for ArrowSchema { const IS_ARROW: bool = true; type DataType = ArrowDataType; - fn get_names_and_dtypes(&'_ self) -> Vec<(&'_ str, Self::DataType)> { - self.fields - .iter() - .map(|x| (x.name.as_str(), x.data_type.clone())) - .collect() - } -} - -impl From<&ArrowSchema> for Schema { - fn from(value: &ArrowSchema) -> Self { - Self::from_iter(value.fields.iter()) - } -} -impl From for Schema { - fn from(value: ArrowSchema) -> Self { - Self::from(&value) + fn iter_names_and_dtypes( + &self, + ) -> impl ExactSizeIterator { + self.iter_values().map(|x| (&x.name, &x.dtype)) } } -impl From for Schema { - fn from(value: ArrowSchemaRef) -> Self { - Self::from(value.as_ref()) - } -} +impl SchemaNamesAndDtypes for Schema { + const IS_ARROW: bool = false; + type DataType = DataType; -impl From<&ArrowSchemaRef> for Schema { - fn from(value: &ArrowSchemaRef) -> Self { - Self::from(value.as_ref()) + fn iter_names_and_dtypes( + &self, + ) -> impl ExactSizeIterator { + self.iter() } } -pub fn ensure_matching_schema(lhs: &S, rhs: &S) -> PolarsResult<()> { - let lhs = lhs.get_names_and_dtypes(); - let rhs = rhs.get_names_and_dtypes(); +pub fn ensure_matching_schema( + lhs: &polars_schema::Schema, + rhs: &polars_schema::Schema, +) -> PolarsResult<()> +where + polars_schema::Schema: SchemaNamesAndDtypes, +{ + let lhs = lhs.iter_names_and_dtypes(); + let rhs = rhs.iter_names_and_dtypes(); if lhs.len() != rhs.len() { polars_bail!( @@ -543,7 +139,7 @@ pub fn ensure_matching_schema(lhs: &S, rhs: &S) -> Pola ); } - for (i, ((l_name, l_dtype), (r_name, r_dtype))) in lhs.iter().zip(&rhs).enumerate() { + for (i, ((l_name, l_dtype), (r_name, r_dtype))) in lhs.zip(rhs).enumerate() { if l_name != r_name { polars_bail!( SchemaMismatch: @@ -552,18 +148,20 @@ pub fn ensure_matching_schema(lhs: &S, rhs: &S) -> Pola ) } if l_dtype != r_dtype - && (!S::IS_ARROW + && (!polars_schema::Schema::::IS_ARROW || unsafe { // For timezone normalization. Easier than writing out the entire PartialEq. 
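`ensure_matching_schema` (now generic over `polars_schema::Schema<D>`) still checks the same three things: equal length, then per-field name equality, then dtype equality. A simplified re-statement of those checks over plain `(name, dtype)` slices, not the crate's generic function:

```rust
use polars_core::prelude::*;

fn check_matching(lhs: &[(&str, DataType)], rhs: &[(&str, DataType)]) -> Result<(), String> {
    if lhs.len() != rhs.len() {
        return Err(format!("lengths differ: {} != {}", lhs.len(), rhs.len()));
    }
    for (i, ((ln, ld), (rn, rd))) in lhs.iter().zip(rhs).enumerate() {
        if ln != rn {
            return Err(format!("field {i}: name mismatch ({ln} != {rn})"));
        }
        if ld != rd {
            return Err(format!("field {i}: dtype mismatch ({ld} != {rd})"));
        }
    }
    Ok(())
}

fn main() {
    let lhs = [("a", DataType::Int64), ("b", DataType::String)];
    let rhs = [("a", DataType::Int64), ("b", DataType::Float64)];
    assert!(check_matching(&lhs, &rhs).is_err());
}
```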
DataType::from_arrow( - std::mem::transmute::<&::DataType, &ArrowDataType>( - l_dtype, - ), + std::mem::transmute::< + & as SchemaNamesAndDtypes>::DataType, + &ArrowDataType, + >(l_dtype), true, ) != DataType::from_arrow( - std::mem::transmute::<&::DataType, &ArrowDataType>( - r_dtype, - ), + std::mem::transmute::< + & as SchemaNamesAndDtypes>::DataType, + &ArrowDataType, + >(r_dtype), true, ) }) diff --git a/crates/polars-core/src/serde/chunked_array.rs b/crates/polars-core/src/serde/chunked_array.rs index 15b8358d62e3..145f05c9af38 100644 --- a/crates/polars-core/src/serde/chunked_array.rs +++ b/crates/polars-core/src/serde/chunked_array.rs @@ -46,7 +46,7 @@ where fn serialize_impl( serializer: S, - name: &str, + name: &PlSmallStr, dtype: &DataType, bit_settings: MetadataFlags, ca: &ChunkedArray, diff --git a/crates/polars-core/src/serde/mod.rs b/crates/polars-core/src/serde/mod.rs index b0157956d8cf..86fbf5c52007 100644 --- a/crates/polars-core/src/serde/mod.rs +++ b/crates/polars-core/src/serde/mod.rs @@ -10,14 +10,14 @@ mod test { #[test] fn test_serde() -> PolarsResult<()> { - let ca = UInt32Chunked::new("foo", &[Some(1), None, Some(2)]); + let ca = UInt32Chunked::new("foo".into(), &[Some(1), None, Some(2)]); let json = serde_json::to_string(&ca).unwrap(); let out = serde_json::from_str::(&json).unwrap(); assert!(ca.into_series().equals_missing(&out)); - let ca = StringChunked::new("foo", &[Some("foo"), None, Some("bar")]); + let ca = StringChunked::new("foo".into(), &[Some("foo"), None, Some("bar")]); let json = serde_json::to_string(&ca).unwrap(); @@ -30,7 +30,7 @@ mod test { /// test using the `DeserializedOwned` trait #[test] fn test_serde_owned() { - let ca = UInt32Chunked::new("foo", &[Some(1), None, Some(2)]); + let ca = UInt32Chunked::new("foo".into(), &[Some(1), None, Some(2)]); let json = serde_json::to_string(&ca).unwrap(); @@ -39,10 +39,10 @@ mod test { } fn sample_dataframe() -> DataFrame { - let s1 = Series::new("foo", &[1, 2, 3]); - let s2 = Series::new("bar", &[Some(true), None, Some(false)]); - let s3 = Series::new("string", &["mouse", "elephant", "dog"]); - let s_list = Series::new("list", &[s1.clone(), s1.clone(), s1.clone()]); + let s1 = Series::new("foo".into(), &[1, 2, 3]); + let s2 = Series::new("bar".into(), &[Some(true), None, Some(false)]); + let s3 = Series::new("string".into(), &["mouse", "elephant", "dog"]); + let s_list = Series::new("list".into(), &[s1.clone(), s1.clone(), s1.clone()]); DataFrame::new(vec![s1, s2, s3, s_list]).unwrap() } @@ -90,7 +90,7 @@ mod test { #[test] fn test_serde_binary_series_owned_bincode() { let s1 = Series::new( - "foo", + "foo".into(), &[ vec![1u8, 2u8, 3u8], vec![4u8, 5u8, 6u8, 7u8], @@ -115,15 +115,15 @@ mod test { AnyValue::String("1:3"), ], vec![ - Field::new("fld_1", DataType::String), - Field::new("fld_2", DataType::String), - Field::new("fld_3", DataType::String), + Field::new("fld_1".into(), DataType::String), + Field::new("fld_2".into(), DataType::String), + Field::new("fld_3".into(), DataType::String), ], ))); let dtype = DataType::Struct(vec![ - Field::new("fld_1", DataType::String), - Field::new("fld_2", DataType::String), - Field::new("fld_3", DataType::String), + Field::new("fld_1".into(), DataType::String), + Field::new("fld_2".into(), DataType::String), + Field::new("fld_3".into(), DataType::String), ]); let row_2 = AnyValue::StructOwned(Box::new(( vec![ @@ -132,15 +132,16 @@ mod test { AnyValue::String("2:3"), ], vec![ - Field::new("fld_1", DataType::String), - Field::new("fld_2", DataType::String), - 
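Most of the churn in these test files is the `&str` → `PlSmallStr` migration for name parameters: construction sites gain `.into()` and name comparisons go through `.as_str()`. A minimal sketch of the new call shape, mirroring the tests in this hunk:

```rust
use polars_core::prelude::*;

fn main() {
    let s = Series::new("foo".into(), &[Some(1i32), None, Some(2)]);
    // `name()` now returns `&PlSmallStr`, so compare via `as_str()`.
    assert_eq!(s.name().as_str(), "foo");
}
```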
Field::new("fld_3", DataType::String), + Field::new("fld_1".into(), DataType::String), + Field::new("fld_2".into(), DataType::String), + Field::new("fld_3".into(), DataType::String), ], ))); let row_3 = AnyValue::Null; - let s = Series::from_any_values_and_dtype("item", &[row_1, row_2, row_3], &dtype, false) - .unwrap(); + let s = + Series::from_any_values_and_dtype("item".into(), &[row_1, row_2, row_3], &dtype, false) + .unwrap(); let df = DataFrame::new(vec![s]).unwrap(); let df_str = serde_json::to_string(&df).unwrap(); diff --git a/crates/polars-core/src/serde/series.rs b/crates/polars-core/src/serde/series.rs index 006edd96d604..3506a0e9cc89 100644 --- a/crates/polars-core/src/serde/series.rs +++ b/crates/polars-core/src/serde/series.rs @@ -2,6 +2,7 @@ use std::borrow::Cow; use std::fmt::Formatter; use serde::de::{Error as DeError, MapAccess, Visitor}; +#[cfg(feature = "object")] use serde::ser::Error as SerError; use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; @@ -145,95 +146,96 @@ impl<'de> Deserialize<'de> for Series { return Err(de::Error::missing_field("values")); } let name = name.ok_or_else(|| de::Error::missing_field("name"))?; + let name = PlSmallStr::from_str(name.as_ref()); let dtype = dtype.ok_or_else(|| de::Error::missing_field("datatype"))?; let mut s = match dtype { #[cfg(feature = "dtype-i8")] DataType::Int8 => { let values: Vec> = map.next_value()?; - Ok(Series::new(&name, values)) + Ok(Series::new(name, values)) }, #[cfg(feature = "dtype-u8")] DataType::UInt8 => { let values: Vec> = map.next_value()?; - Ok(Series::new(&name, values)) + Ok(Series::new(name, values)) }, #[cfg(feature = "dtype-i16")] DataType::Int16 => { let values: Vec> = map.next_value()?; - Ok(Series::new(&name, values)) + Ok(Series::new(name, values)) }, #[cfg(feature = "dtype-u16")] DataType::UInt16 => { let values: Vec> = map.next_value()?; - Ok(Series::new(&name, values)) + Ok(Series::new(name, values)) }, DataType::Int32 => { let values: Vec> = map.next_value()?; - Ok(Series::new(&name, values)) + Ok(Series::new(name, values)) }, DataType::UInt32 => { let values: Vec> = map.next_value()?; - Ok(Series::new(&name, values)) + Ok(Series::new(name, values)) }, DataType::Int64 => { let values: Vec> = map.next_value()?; - Ok(Series::new(&name, values)) + Ok(Series::new(name, values)) }, DataType::UInt64 => { let values: Vec> = map.next_value()?; - Ok(Series::new(&name, values)) + Ok(Series::new(name, values)) }, #[cfg(feature = "dtype-date")] DataType::Date => { let values: Vec> = map.next_value()?; - Ok(Series::new(&name, values).cast(&DataType::Date).unwrap()) + Ok(Series::new(name, values).cast(&DataType::Date).unwrap()) }, #[cfg(feature = "dtype-datetime")] DataType::Datetime(tu, tz) => { let values: Vec> = map.next_value()?; - Ok(Series::new(&name, values) + Ok(Series::new(name, values) .cast(&DataType::Datetime(tu, tz)) .unwrap()) }, #[cfg(feature = "dtype-duration")] DataType::Duration(tu) => { let values: Vec> = map.next_value()?; - Ok(Series::new(&name, values) + Ok(Series::new(name, values) .cast(&DataType::Duration(tu)) .unwrap()) }, #[cfg(feature = "dtype-time")] DataType::Time => { let values: Vec> = map.next_value()?; - Ok(Series::new(&name, values).cast(&DataType::Time).unwrap()) + Ok(Series::new(name, values).cast(&DataType::Time).unwrap()) }, #[cfg(feature = "dtype-decimal")] DataType::Decimal(precision, Some(scale)) => { let values: Vec> = map.next_value()?; - Ok(ChunkedArray::from_slice_options(&name, &values) + Ok(ChunkedArray::from_slice_options(name, &values) 
.into_decimal_unchecked(precision, scale) .into_series()) }, DataType::Boolean => { let values: Vec> = map.next_value()?; - Ok(Series::new(&name, values)) + Ok(Series::new(name, values)) }, DataType::Float32 => { let values: Vec> = map.next_value()?; - Ok(Series::new(&name, values)) + Ok(Series::new(name, values)) }, DataType::Float64 => { let values: Vec> = map.next_value()?; - Ok(Series::new(&name, values)) + Ok(Series::new(name, values)) }, DataType::String => { let values: Vec>> = map.next_value()?; - Ok(Series::new(&name, values)) + Ok(Series::new(name, values)) }, DataType::List(inner) => { let values: Vec> = map.next_value()?; - let mut lb = AnonymousListBuilder::new(&name, values.len(), Some(*inner)); + let mut lb = AnonymousListBuilder::new(name, values.len(), Some(*inner)); for value in &values { lb.append_opt_series(value.as_ref()).map_err(|e| { de::Error::custom(format!("could not append series to list: {e}")) @@ -245,7 +247,7 @@ impl<'de> Deserialize<'de> for Series { DataType::Array(inner, width) => { let values: Vec> = map.next_value()?; let mut builder = - get_fixed_size_list_builder(&inner, values.len(), width, &name) + get_fixed_size_list_builder(&inner, values.len(), width, name) .map_err(|e| { de::Error::custom(format!( "could not get supported list builder: {e}" @@ -270,25 +272,25 @@ impl<'de> Deserialize<'de> for Series { }, DataType::Binary => { let values: Vec>> = map.next_value()?; - Ok(Series::new(&name, values)) + Ok(Series::new(name, values)) }, #[cfg(feature = "dtype-struct")] DataType::Struct(_) => { let values: Vec = map.next_value()?; - let ca = StructChunked::from_series(&name, &values).unwrap(); + let ca = StructChunked::from_series(name.clone(), &values).unwrap(); let mut s = ca.into_series(); - s.rename(&name); + s.rename(name); Ok(s) }, #[cfg(feature = "dtype-categorical")] dt @ (DataType::Categorical(_, _) | DataType::Enum(_, _)) => { let values: Vec>> = map.next_value()?; - Ok(Series::new(&name, values).cast(&dt).unwrap()) + Ok(Series::new(name, values).cast(&dt).unwrap()) }, DataType::Null => { let values: Vec = map.next_value()?; let len = values.first().unwrap(); - Ok(Series::new_null(&name, *len)) + Ok(Series::new_null(name, *len)) }, dt => Err(A::Error::custom(format!( "deserializing data of type {dt} is not supported" diff --git a/crates/polars-core/src/series/amortized_iter.rs b/crates/polars-core/src/series/amortized_iter.rs index 7cdf8507c29f..7d32bfcb4bf5 100644 --- a/crates/polars-core/src/series/amortized_iter.rs +++ b/crates/polars-core/src/series/amortized_iter.rs @@ -51,8 +51,8 @@ impl AmortSeries { let s = &(*self.container); debug_assert_eq!(s.chunks().len(), 1); let array_ref = s.chunks().get_unchecked(0).clone(); - let name = s.name(); - Series::from_chunks_and_dtype_unchecked(name, vec![array_ref], s.dtype()) + let name = s.name().clone(); + Series::from_chunks_and_dtype_unchecked(name.clone(), vec![array_ref], s.dtype()) } } @@ -93,7 +93,7 @@ impl AmortSeries { // SAFETY: // type must be matching pub(crate) unsafe fn unstable_series_container_and_ptr( - name: &str, + name: PlSmallStr, inner_values: ArrayRef, iter_dtype: &DataType, ) -> (Series, *mut ArrayRef) { diff --git a/crates/polars-core/src/series/any_value.rs b/crates/polars-core/src/series/any_value.rs index 83abf75e980d..aaa4bc753443 100644 --- a/crates/polars-core/src/series/any_value.rs +++ b/crates/polars-core/src/series/any_value.rs @@ -18,7 +18,7 @@ impl<'a, T: AsRef<[AnyValue<'a>]>> NamedFrom]> for Series { /// of [`DataType::Null`], which is always allowed). 
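For context on the `Deserialize` arms above: the serde round-trip looks the same from user code, only the name handling changed. A short sketch mirroring the crate's own test (needs polars-core's `serde` feature plus `serde_json`):

```rust
use polars_core::prelude::*;

fn main() {
    let s = UInt32Chunked::new("foo".into(), &[Some(1), None, Some(2)]).into_series();
    let json = serde_json::to_string(&s).unwrap();
    let out: Series = serde_json::from_str(&json).unwrap();
    // Nulls survive the round-trip, hence `equals_missing`.
    assert!(s.equals_missing(&out));
}
```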
/// /// [`AnyValue`]: crate::datatypes::AnyValue - fn new(name: &str, values: T) -> Self { + fn new(name: PlSmallStr, values: T) -> Self { let values = values.as_ref(); Series::from_any_values(name, values, true).expect("data types of values should match") } @@ -36,7 +36,11 @@ impl Series { /// An error is returned if no supertype can be determined. /// **WARNING**: A full pass over the values is required to determine the supertype. /// - If no values were passed, the resulting data type is `Null`. - pub fn from_any_values(name: &str, values: &[AnyValue], strict: bool) -> PolarsResult { + pub fn from_any_values( + name: PlSmallStr, + values: &[AnyValue], + strict: bool, + ) -> PolarsResult { fn get_first_non_null_dtype(values: &[AnyValue]) -> DataType { let mut all_flat_null = true; let first_non_null = values.iter().find(|av| { @@ -82,7 +86,7 @@ impl Series { /// data type. If `strict` is `false`, values that do not match the given data type /// are cast. If casting is not possible, the values are set to null instead. pub fn from_any_values_and_dtype( - name: &str, + name: PlSmallStr, values: &[AnyValue], dtype: &DataType, strict: bool, @@ -158,7 +162,7 @@ impl Series { DataType::Struct(fields) => any_values_to_struct(values, fields, strict)?, #[cfg(feature = "object")] DataType::Object(_, registry) => any_values_to_object(values, registry)?, - DataType::Null => Series::new_null(name, values.len()), + DataType::Null => Series::new_null(PlSmallStr::EMPTY, values.len()), dt => { polars_bail!( InvalidOperation: @@ -185,9 +189,9 @@ fn any_values_to_integer( fn any_values_to_integer_strict( values: &[AnyValue], ) -> PolarsResult> { - let mut builder = PrimitiveChunkedBuilder::::new("", values.len()); + let mut builder = PrimitiveChunkedBuilder::::new(PlSmallStr::EMPTY, values.len()); for av in values { - match av { + match &av { av if av.is_integer() => { let opt_val = av.extract::(); let val = match opt_val { @@ -212,7 +216,8 @@ fn any_values_to_integer( fn any_values_to_f32(values: &[AnyValue], strict: bool) -> PolarsResult { fn any_values_to_f32_strict(values: &[AnyValue]) -> PolarsResult { - let mut builder = PrimitiveChunkedBuilder::::new("", values.len()); + let mut builder = + PrimitiveChunkedBuilder::::new(PlSmallStr::EMPTY, values.len()); for av in values { match av { AnyValue::Float32(i) => builder.append_value(*i), @@ -230,7 +235,8 @@ fn any_values_to_f32(values: &[AnyValue], strict: bool) -> PolarsResult PolarsResult { fn any_values_to_f64_strict(values: &[AnyValue]) -> PolarsResult { - let mut builder = PrimitiveChunkedBuilder::::new("", values.len()); + let mut builder = + PrimitiveChunkedBuilder::::new(PlSmallStr::EMPTY, values.len()); for av in values { match av { AnyValue::Float64(i) => builder.append_value(*i), @@ -249,7 +255,7 @@ fn any_values_to_f64(values: &[AnyValue], strict: bool) -> PolarsResult PolarsResult { - let mut builder = BooleanChunkedBuilder::new("", values.len()); + let mut builder = BooleanChunkedBuilder::new(PlSmallStr::EMPTY, values.len()); for av in values { match av { AnyValue::Boolean(b) => builder.append_value(*b), @@ -270,7 +276,7 @@ fn any_values_to_bool(values: &[AnyValue], strict: bool) -> PolarsResult PolarsResult { fn any_values_to_string_strict(values: &[AnyValue]) -> PolarsResult { - let mut builder = StringChunkedBuilder::new("", values.len()); + let mut builder = StringChunkedBuilder::new(PlSmallStr::EMPTY, values.len()); for av in values { match av { AnyValue::String(s) => builder.append_value(s), @@ -282,7 +288,7 @@ fn 
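`from_any_values` keeps its behaviour: infer a dtype from the values (a full pass to find the supertype), then build, casting when `strict` is `false`. A hedged usage sketch with the new `PlSmallStr` name parameter:

```rust
use polars_core::prelude::*;

fn main() -> PolarsResult<()> {
    let values = [AnyValue::Int64(1), AnyValue::Null, AnyValue::Float64(2.5)];
    // With strict = false, mixed ints/floats are widened to the supertype.
    let s = Series::from_any_values("mixed".into(), &values, false)?;
    println!("{:?}", s.dtype()); // expected: Float64
    Ok(())
}
```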
any_values_to_string(values: &[AnyValue], strict: bool) -> PolarsResult StringChunked { - let mut builder = StringChunkedBuilder::new("", values.len()); + let mut builder = StringChunkedBuilder::new(PlSmallStr::EMPTY, values.len()); let mut owned = String::new(); // Amortize allocations. for av in values { match av { @@ -308,7 +314,7 @@ fn any_values_to_string(values: &[AnyValue], strict: bool) -> PolarsResult PolarsResult { fn any_values_to_binary_strict(values: &[AnyValue]) -> PolarsResult { - let mut builder = BinaryChunkedBuilder::new("", values.len()); + let mut builder = BinaryChunkedBuilder::new(PlSmallStr::EMPTY, values.len()); for av in values { match av { AnyValue::Binary(s) => builder.append_value(*s), @@ -326,7 +332,7 @@ fn any_values_to_binary(values: &[AnyValue], strict: bool) -> PolarsResult Some(*b), AnyValue::BinaryOwned(b) => Some(&**b), AnyValue::String(s) => Some(s.as_bytes()), - AnyValue::StringOwned(s) => Some(s.as_bytes()), + AnyValue::StringOwned(s) => Some(s.as_str().as_bytes()), _ => None, }) .collect_trusted() @@ -340,7 +346,7 @@ fn any_values_to_binary(values: &[AnyValue], strict: bool) -> PolarsResult PolarsResult { - let mut builder = PrimitiveChunkedBuilder::::new("", values.len()); + let mut builder = PrimitiveChunkedBuilder::::new(PlSmallStr::EMPTY, values.len()); for av in values { match av { AnyValue::Date(i) => builder.append_value(*i), @@ -361,7 +367,7 @@ fn any_values_to_date(values: &[AnyValue], strict: bool) -> PolarsResult PolarsResult { - let mut builder = PrimitiveChunkedBuilder::::new("", values.len()); + let mut builder = PrimitiveChunkedBuilder::::new(PlSmallStr::EMPTY, values.len()); for av in values { match av { AnyValue::Time(i) => builder.append_value(*i), @@ -387,7 +393,7 @@ fn any_values_to_datetime( time_zone: Option, strict: bool, ) -> PolarsResult { - let mut builder = PrimitiveChunkedBuilder::::new("", values.len()); + let mut builder = PrimitiveChunkedBuilder::::new(PlSmallStr::EMPTY, values.len()); let target_dtype = DataType::Datetime(time_unit, time_zone.clone()); for av in values { match av { @@ -413,7 +419,7 @@ fn any_values_to_duration( time_unit: TimeUnit, strict: bool, ) -> PolarsResult { - let mut builder = PrimitiveChunkedBuilder::::new("", values.len()); + let mut builder = PrimitiveChunkedBuilder::::new(PlSmallStr::EMPTY, values.len()); let target_dtype = DataType::Duration(time_unit); for av in values { match av { @@ -485,7 +491,7 @@ fn any_values_to_decimal( }; let target_dtype = DataType::Decimal(precision, Some(scale)); - let mut builder = PrimitiveChunkedBuilder::::new("", values.len()); + let mut builder = PrimitiveChunkedBuilder::::new(PlSmallStr::EMPTY, values.len()); for av in values { match av { // Allow equal or less scale. We do want to support different scales even in 'strict' mode. 
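These `any_values_to_*` helpers now build their intermediate arrays with `PlSmallStr::EMPTY` and leave naming to the caller. The same build-then-rename pattern from user code, as a sketch (assuming `PrimitiveChunkedBuilder` and `PlSmallStr` are re-exported via the prelude):

```rust
use polars_core::prelude::*;

fn main() {
    let mut builder = PrimitiveChunkedBuilder::<Int32Type>::new(PlSmallStr::EMPTY, 3);
    builder.append_value(1);
    builder.append_null();
    builder.append_value(3);
    let mut s = builder.finish().into_series();
    // Attach the real name afterwards; `rename` now takes an owned PlSmallStr.
    s.rename("values".into());
    assert_eq!(s.name().as_str(), "values");
}
```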
@@ -527,7 +533,7 @@ fn any_values_to_list( // We must ensure the data-types match what we do physical #[cfg(feature = "dtype-struct")] DataType::Struct(fields) if fields.is_empty() => { - DataType::Struct(vec![Field::new("", DataType::Null)]) + DataType::Struct(vec![Field::new(PlSmallStr::EMPTY, DataType::Null)]) }, _ => inner_type.clone(), }; @@ -558,7 +564,9 @@ fn any_values_to_list( } else { match b.cast(inner_type) { Ok(out) => Some(out), - Err(_) => Some(Series::full_null(b.name(), b.len(), inner_type)), + Err(_) => { + Some(Series::full_null(b.name().clone(), b.len(), inner_type)) + }, } } }, @@ -617,7 +625,7 @@ fn any_values_to_array( None }, }) - .collect_ca_with_dtype("", target_dtype.clone()) + .collect_ca_with_dtype(PlSmallStr::EMPTY, target_dtype.clone()) } // Make sure that wrongly inferred AnyValues don't deviate from the datatype. else { @@ -629,7 +637,7 @@ fn any_values_to_array( } else { let s = match b.cast(inner_type) { Ok(out) => out, - Err(_) => Series::full_null(b.name(), b.len(), inner_type), + Err(_) => Series::full_null(b.name().clone(), b.len(), inner_type), }; to_arr(&s) } @@ -640,7 +648,7 @@ fn any_values_to_array( None }, }) - .collect_ca_with_dtype("", target_dtype.clone()) + .collect_ca_with_dtype(PlSmallStr::EMPTY, target_dtype.clone()) }; if strict && !valid { @@ -670,7 +678,7 @@ fn any_values_to_struct( ) -> PolarsResult { // Fast path for structs with no fields. if fields.is_empty() { - return Ok(StructChunked::full_null("", values.len()).into_series()); + return Ok(StructChunked::full_null(PlSmallStr::EMPTY, values.len()).into_series()); } // The physical series fields of the struct. @@ -723,14 +731,19 @@ fn any_values_to_struct( } // If the inferred dtype is null, we let auto inference work. let s = if matches!(field.dtype, DataType::Null) { - Series::from_any_values(field.name(), &field_avs, strict)? + Series::from_any_values(field.name().clone(), &field_avs, strict)? } else { - Series::from_any_values_and_dtype(field.name(), &field_avs, &field.dtype, strict)? + Series::from_any_values_and_dtype( + field.name().clone(), + &field_avs, + &field.dtype, + strict, + )? 
}; series_fields.push(s) } - let mut out = StructChunked::from_series("", &series_fields)?; + let mut out = StructChunked::from_series(PlSmallStr::EMPTY, &series_fields)?; if has_outer_validity { let mut validity = MutableBitmap::new(); validity.extend_constant(values.len(), true); @@ -753,7 +766,7 @@ fn any_values_to_object( None => { use crate::chunked_array::object::registry; let converter = registry::get_object_converter(); - let mut builder = registry::get_object_builder("", values.len()); + let mut builder = registry::get_object_builder(PlSmallStr::EMPTY, values.len()); for av in values { match av { AnyValue::Object(val) => builder.append_value(val.as_any()), @@ -769,7 +782,7 @@ fn any_values_to_object( builder }, Some(registry) => { - let mut builder = (*registry.builder_constructor)("", values.len()); + let mut builder = (*registry.builder_constructor)(PlSmallStr::EMPTY, values.len()); for av in values { match av { AnyValue::Object(val) => builder.append_value(val.as_any()), diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index f6f80107b081..65cce0e9450d 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -714,7 +714,7 @@ impl Mul for &Series { (_, Duration(_)) => { // swap order let out = rhs.multiply(self)?; - Ok(out.with_name(self.name())) + Ok(out.with_name(self.name().clone())) }, _ => { let (lhs, rhs) = coerce_lhs_rhs(self, rhs)?; @@ -1058,7 +1058,7 @@ mod test { #[allow(clippy::eq_op)] fn test_arithmetic_series() -> PolarsResult<()> { // Series +-/* Series - let s = Series::new("foo", [1, 2, 3]); + let s = Series::new("foo".into(), [1, 2, 3]); assert_eq!( Vec::from((&s * &s)?.i32().unwrap()), [Some(1), Some(4), Some(9)] @@ -1115,9 +1115,9 @@ mod test { [Some(0), Some(1), Some(1)] ); - assert_eq!((&s * &s)?.name(), "foo"); - assert_eq!((&s * 1).name(), "foo"); - assert_eq!((1.div(&s)).name(), "foo"); + assert_eq!((&s * &s)?.name().as_str(), "foo"); + assert_eq!((&s * 1).name().as_str(), "foo"); + assert_eq!((1.div(&s)).name().as_str(), "foo"); Ok(()) } @@ -1125,13 +1125,13 @@ mod test { #[test] #[cfg(feature = "checked_arithmetic")] fn test_checked_div() { - let s = Series::new("foo", [1i32, 0, 1]); + let s = Series::new("foo".into(), [1i32, 0, 1]); let out = s.checked_div(&s).unwrap(); assert_eq!(Vec::from(out.i32().unwrap()), &[Some(1), None, Some(1)]); let out = s.checked_div_num(0).unwrap(); assert_eq!(Vec::from(out.i32().unwrap()), &[None, None, None]); - let s_f32 = Series::new("float32", [1.0f32, 0.0, 1.0]); + let s_f32 = Series::new("float32".into(), [1.0f32, 0.0, 1.0]); let out = s_f32.checked_div(&s_f32).unwrap(); assert_eq!( Vec::from(out.f32().unwrap()), @@ -1140,7 +1140,7 @@ mod test { let out = s_f32.checked_div_num(0.0f32).unwrap(); assert_eq!(Vec::from(out.f32().unwrap()), &[None, None, None]); - let s_f64 = Series::new("float64", [1.0f64, 0.0, 1.0]); + let s_f64 = Series::new("float64".into(), [1.0f64, 0.0, 1.0]); let out = s_f64.checked_div(&s_f64).unwrap(); assert_eq!( Vec::from(out.f64().unwrap()), diff --git a/crates/polars-core/src/series/comparison.rs b/crates/polars-core/src/series/comparison.rs index cdb5aea3bcc9..6ccb4db7c219 100644 --- a/crates/polars-core/src/series/comparison.rs +++ b/crates/polars-core/src/series/comparison.rs @@ -17,21 +17,21 @@ macro_rules! impl_compare { .categorical() .unwrap() .$method(rhs.categorical().unwrap())? 
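The arithmetic and comparison impls now clone the owned name (`lhs.name().clone()`) because `with_name` and `rename` take `PlSmallStr` by value. From the caller's side, name propagation is unchanged; a small sketch:

```rust
use polars_core::prelude::*;

fn main() {
    let s = Series::new("a".into(), [1i32, 2, 3]);
    // Binary ops keep the left-hand name, as asserted in the tests above.
    let out = (&s * &s).unwrap();
    assert_eq!(out.name().as_str(), "a");
    let renamed = out.with_name("a_squared".into());
    assert_eq!(renamed.name().as_str(), "a_squared");
}
```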
- .with_name(lhs.name())); + .with_name(lhs.name().clone())); }, (Categorical(_, _) | Enum(_, _), String) => { return Ok(lhs .categorical() .unwrap() .$method(rhs.str().unwrap())? - .with_name(lhs.name())); + .with_name(lhs.name().clone())); }, (String, Categorical(_, _) | Enum(_, _)) => { return Ok(rhs .categorical() .unwrap() .$method(lhs.str().unwrap())? - .with_name(lhs.name())); + .with_name(lhs.name().clone())); }, _ => (), }; @@ -80,7 +80,7 @@ macro_rules! impl_compare { dt => polars_bail!(InvalidOperation: "could not apply comparison on series of dtype '{}; operand names: '{}', '{}'", dt, lhs.name(), rhs.name()), }; - out.rename(lhs.name()); + out.rename(lhs.name().clone()); PolarsResult::Ok(out) }}; } @@ -240,7 +240,7 @@ impl ChunkCompare<&str> for Series { DataType::Categorical(_, _) | DataType::Enum(_, _) => { self.categorical().unwrap().equal(rhs) }, - _ => Ok(BooleanChunked::full(self.name(), false, self.len())), + _ => Ok(BooleanChunked::full(self.name().clone(), false, self.len())), } } @@ -252,7 +252,11 @@ impl ChunkCompare<&str> for Series { DataType::Categorical(_, _) | DataType::Enum(_, _) => { self.categorical().unwrap().equal_missing(rhs) }, - _ => Ok(replace_non_null(self.name(), self.0.chunks(), false)), + _ => Ok(replace_non_null( + self.name().clone(), + self.0.chunks(), + false, + )), } } @@ -264,7 +268,7 @@ impl ChunkCompare<&str> for Series { DataType::Categorical(_, _) | DataType::Enum(_, _) => { self.categorical().unwrap().not_equal(rhs) }, - _ => Ok(BooleanChunked::full(self.name(), true, self.len())), + _ => Ok(BooleanChunked::full(self.name().clone(), true, self.len())), } } @@ -276,7 +280,7 @@ impl ChunkCompare<&str> for Series { DataType::Categorical(_, _) | DataType::Enum(_, _) => { self.categorical().unwrap().not_equal_missing(rhs) }, - _ => Ok(replace_non_null(self.name(), self.0.chunks(), true)), + _ => Ok(replace_non_null(self.name().clone(), self.0.chunks(), true)), } } diff --git a/crates/polars-core/src/series/from.rs b/crates/polars-core/src/series/from.rs index 5062b7230476..ce473a4d60fb 100644 --- a/crates/polars-core/src/series/from.rs +++ b/crates/polars-core/src/series/from.rs @@ -31,7 +31,7 @@ impl Series { /// /// The caller must ensure that the given `dtype`'s physical type matches all the `ArrayRef` dtypes. pub unsafe fn from_chunks_and_dtype_unchecked( - name: &str, + name: PlSmallStr, chunks: Vec, dtype: &DataType, ) -> Self { @@ -121,7 +121,7 @@ impl Series { // (the pid is checked before dereference) { let pe = PolarsExtension::new(arr.clone()); - let s = pe.get_series(name); + let s = pe.get_series(&name); pe.take_and_forget(); s } @@ -138,7 +138,7 @@ impl Series { /// # Safety /// The caller must ensure that the given `dtype` matches all the `ArrayRef` dtypes. pub unsafe fn _try_from_arrow_unchecked( - name: &str, + name: PlSmallStr, chunks: Vec, dtype: &ArrowDataType, ) -> PolarsResult { @@ -150,7 +150,7 @@ impl Series { /// # Safety /// The caller must ensure that the given `dtype` matches all the `ArrayRef` dtypes. 
pub unsafe fn _try_from_arrow_unchecked_with_md( - name: &str, + name: PlSmallStr, chunks: Vec, dtype: &ArrowDataType, md: Option<&Metadata>, @@ -393,7 +393,7 @@ impl Series { // (the pid is checked before dereference) let s = { let pe = PolarsExtension::new(arr.clone()); - let s = pe.get_series(name); + let s = pe.get_series(&name); pe.take_and_forget(); s }; @@ -459,7 +459,7 @@ impl Series { } } -fn map_arrays_to_series(name: &str, chunks: Vec) -> PolarsResult { +fn map_arrays_to_series(name: PlSmallStr, chunks: Vec) -> PolarsResult { let chunks = chunks .iter() .map(|arr| { @@ -468,9 +468,9 @@ fn map_arrays_to_series(name: &str, chunks: Vec) -> PolarsResult::default_datatype(inner.data_type().clone()); + let dtype = ListArray::::default_datatype(inner.dtype().clone()); Box::new(ListArray::::new( - data_type, + dtype, arr.offsets().clone(), inner, arr.validity().cloned(), @@ -490,7 +490,7 @@ unsafe fn to_physical_and_dtype( arrays: Vec, md: Option<&Metadata>, ) -> (Vec, DataType) { - match arrays[0].data_type() { + match arrays[0].dtype() { ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => { let chunks = cast_chunks(&arrays, &DataType::String, CastOptions::NonStrict).unwrap(); (chunks, DataType::String) @@ -504,7 +504,7 @@ unsafe fn to_physical_and_dtype( feature_gated!("dtype-categorical", { let s = unsafe { let dt = dt.clone(); - Series::_try_from_arrow_unchecked_with_md("", arrays, &dt, md) + Series::_try_from_arrow_unchecked_with_md(PlSmallStr::EMPTY, arrays, &dt, md) } .unwrap(); (s.chunks().clone(), s.dtype().clone()) @@ -538,7 +538,7 @@ unsafe fn to_physical_and_dtype( let arr = arr.as_any().downcast_ref::().unwrap(); let dtype = - FixedSizeListArray::default_datatype(values.data_type().clone(), *size); + FixedSizeListArray::default_datatype(values.dtype().clone(), *size); Box::from(FixedSizeListArray::new( dtype, values, @@ -566,7 +566,7 @@ unsafe fn to_physical_and_dtype( .map(|(arr, values)| { let arr = arr.as_any().downcast_ref::>().unwrap(); - let dtype = ListArray::::default_datatype(values.data_type().clone()); + let dtype = ListArray::::default_datatype(values.dtype().clone()); Box::from(ListArray::::new( dtype, arr.offsets().clone(), @@ -596,7 +596,9 @@ unsafe fn to_physical_and_dtype( let arrow_fields = values .iter() .zip(_fields.iter()) - .map(|(arr, field)| ArrowField::new(&field.name, arr.data_type().clone(), true)) + .map(|(arr, field)| { + ArrowField::new(field.name.clone(), arr.dtype().clone(), true) + }) .collect(); let arrow_array = Box::new(StructArray::new( ArrowDataType::Struct(arrow_fields), @@ -606,7 +608,7 @@ unsafe fn to_physical_and_dtype( let polars_fields = _fields .iter() .zip(dtypes) - .map(|(field, dtype)| Field::new(&field.name, dtype)) + .map(|(field, dtype)| Field::new(field.name.clone(), dtype)) .collect(); (vec![arrow_array], DataType::Struct(polars_fields)) }) @@ -620,7 +622,7 @@ unsafe fn to_physical_and_dtype( | ArrowDataType::Decimal(_, _) | ArrowDataType::Date64) => { let dt = dt.clone(); - let mut s = Series::_try_from_arrow_unchecked("", arrays, &dt).unwrap(); + let mut s = Series::_try_from_arrow_unchecked(PlSmallStr::EMPTY, arrays, &dt).unwrap(); let dtype = s.dtype().clone(); (std::mem::take(s.chunks_mut()), dtype) }, @@ -633,39 +635,53 @@ unsafe fn to_physical_and_dtype( fn check_types(chunks: &[ArrayRef]) -> PolarsResult { let mut chunks_iter = chunks.iter(); - let data_type: ArrowDataType = chunks_iter + let dtype: ArrowDataType = chunks_iter .next() .ok_or_else(|| polars_err!(NoData: "expected at least one array-ref"))? 
- .data_type() + .dtype() .clone(); for chunk in chunks_iter { - if chunk.data_type() != &data_type { + if chunk.dtype() != &dtype { polars_bail!( ComputeError: "cannot create series from multiple arrays with different types" ); } } - Ok(data_type) + Ok(dtype) +} + +impl Series { + pub fn try_new( + name: PlSmallStr, + data: T, + ) -> Result>::Error> + where + (PlSmallStr, T): TryInto, + { + // # TODO + // * Remove the TryFrom impls in favor of this + <(PlSmallStr, T) as TryInto>::try_into((name, data)) + } } -impl TryFrom<(&str, Vec)> for Series { +impl TryFrom<(PlSmallStr, Vec)> for Series { type Error = PolarsError; - fn try_from(name_arr: (&str, Vec)) -> PolarsResult { + fn try_from(name_arr: (PlSmallStr, Vec)) -> PolarsResult { let (name, chunks) = name_arr; - let data_type = check_types(&chunks)?; + let dtype = check_types(&chunks)?; // SAFETY: // dtype is checked - unsafe { Series::_try_from_arrow_unchecked(name, chunks, &data_type) } + unsafe { Series::_try_from_arrow_unchecked(name, chunks, &dtype) } } } -impl TryFrom<(&str, ArrayRef)> for Series { +impl TryFrom<(PlSmallStr, ArrayRef)> for Series { type Error = PolarsError; - fn try_from(name_arr: (&str, ArrayRef)) -> PolarsResult { + fn try_from(name_arr: (PlSmallStr, ArrayRef)) -> PolarsResult { let (name, arr) = name_arr; Series::try_from((name, vec![arr])) } @@ -677,15 +693,15 @@ impl TryFrom<(&ArrowField, Vec)> for Series { fn try_from(field_arr: (&ArrowField, Vec)) -> PolarsResult { let (field, chunks) = field_arr; - let data_type = check_types(&chunks)?; + let dtype = check_types(&chunks)?; // SAFETY: // dtype is checked unsafe { Series::_try_from_arrow_unchecked_with_md( - &field.name, + field.name.clone(), chunks, - &data_type, + &dtype, Some(&field.metadata), ) } @@ -772,7 +788,7 @@ unsafe impl IntoSeries for Series { } } -fn new_null(name: &str, chunks: &[ArrayRef]) -> Series { +fn new_null(name: PlSmallStr, chunks: &[ArrayRef]) -> Series { let len = chunks.iter().map(|arr| arr.len()).sum(); Series::new_null(name, len) } diff --git a/crates/polars-core/src/series/implementations/array.rs b/crates/polars-core/src/series/implementations/array.rs index bc3ed6d23243..51bd084cd46d 100644 --- a/crates/polars-core/src/series/implementations/array.rs +++ b/crates/polars-core/src/series/implementations/array.rs @@ -4,7 +4,6 @@ use std::borrow::Cow; use super::{private, MetadataFlags}; use crate::chunked_array::cast::CastOptions; use crate::chunked_array::comparison::*; -use crate::chunked_array::ops::explode::ExplodeByOffsets; use crate::chunked_array::AsSinglePtr; #[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; @@ -19,7 +18,7 @@ impl private::PrivateSeries for SeriesWrap { Cow::Borrowed(self.0.ref_field()) } fn _dtype(&self) -> &DataType { - self.0.ref_field().data_type() + self.0.ref_field().dtype() } fn _get_flags(&self) -> MetadataFlags { @@ -30,10 +29,6 @@ impl private::PrivateSeries for SeriesWrap { self.0.set_flags(flags) } - fn explode_by_offsets(&self, offsets: &[i64]) -> Series { - self.0.explode_by_offsets(offsets) - } - unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { self.0.equal_element(idx_self, idx_other, other) } @@ -73,14 +68,14 @@ impl private::PrivateSeries for SeriesWrap { } impl SeriesTrait for SeriesWrap { - fn rename(&mut self, name: &str) { + fn rename(&mut self, name: PlSmallStr) { self.0.rename(name); } fn chunk_lengths(&self) -> ChunkLenIter { self.0.chunk_lengths() } - fn name(&self) -> &str { + fn name(&self) -> &PlSmallStr { 
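The new `Series::try_new` is a thin forwarder to the `TryFrom<(PlSmallStr, T)>` impls, e.g. from a single boxed arrow array. A sketch under the assumption that `polars_arrow::array::PrimitiveArray::from_slice` is available as in the arrow2 lineage; the import paths here are illustrative:

```rust
use polars_arrow::array::{Array, PrimitiveArray};
use polars_core::prelude::*;

fn main() -> PolarsResult<()> {
    let arr: Box<dyn Array> = Box::new(PrimitiveArray::<i32>::from_slice([1, 2, 3]));
    // Equivalent to Series::try_from(("ints".into(), arr)).
    let s = Series::try_new("ints".into(), arr)?;
    assert_eq!(s.dtype(), &DataType::Int32);
    Ok(())
}
```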
self.0.name() } @@ -146,8 +141,8 @@ impl SeriesTrait for SeriesWrap { ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() } - fn cast(&self, data_type: &DataType, options: CastOptions) -> PolarsResult { - self.0.cast_with_options(data_type, options) + fn cast(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, options) } fn get(&self, index: usize) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/binary.rs b/crates/polars-core/src/series/implementations/binary.rs index 4c92802861a3..8cdf326302d1 100644 --- a/crates/polars-core/src/series/implementations/binary.rs +++ b/crates/polars-core/src/series/implementations/binary.rs @@ -13,7 +13,7 @@ impl private::PrivateSeries for SeriesWrap { Cow::Borrowed(self.0.ref_field()) } fn _dtype(&self) -> &DataType { - self.0.ref_field().data_type() + self.0.ref_field().dtype() } fn _get_flags(&self) -> MetadataFlags { self.0.get_flags() @@ -21,9 +21,6 @@ impl private::PrivateSeries for SeriesWrap { fn _set_flags(&mut self, flags: MetadataFlags) { self.0.set_flags(flags) } - fn explode_by_offsets(&self, offsets: &[i64]) -> Series { - self.0.explode_by_offsets(offsets) - } unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { self.0.equal_element(idx_self, idx_other, other) @@ -40,12 +37,16 @@ impl private::PrivateSeries for SeriesWrap { (&self.0).into_total_ord_inner() } - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash(&self, random_state: PlRandomState, buf: &mut Vec) -> PolarsResult<()> { self.0.vec_hash(random_state, buf)?; Ok(()) } - fn vec_hash_combine(&self, build_hasher: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { + fn vec_hash_combine( + &self, + build_hasher: PlRandomState, + hashes: &mut [u64], + ) -> PolarsResult<()> { self.0.vec_hash_combine(build_hasher, hashes)?; Ok(()) } @@ -95,14 +96,14 @@ impl private::PrivateSeries for SeriesWrap { } impl SeriesTrait for SeriesWrap { - fn rename(&mut self, name: &str) { + fn rename(&mut self, name: PlSmallStr) { self.0.rename(name); } fn chunk_lengths(&self) -> ChunkLenIter { self.0.chunk_lengths() } - fn name(&self) -> &str { + fn name(&self) -> &PlSmallStr { self.0.name() } @@ -127,13 +128,13 @@ impl SeriesTrait for SeriesWrap { fn append(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), append); // todo! 
add object - self.0.append(other.as_ref().as_ref()); + self.0.append(other.as_ref().as_ref())?; Ok(()) } fn extend(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), extend); - self.0.extend(other.as_ref().as_ref()); + self.0.extend(other.as_ref().as_ref())?; Ok(()) } @@ -169,8 +170,8 @@ impl SeriesTrait for SeriesWrap { ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() } - fn cast(&self, data_type: &DataType, options: CastOptions) -> PolarsResult { - self.0.cast_with_options(data_type, options) + fn cast(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, options) } fn get(&self, index: usize) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/binary_offset.rs b/crates/polars-core/src/series/implementations/binary_offset.rs index e3612e2cbe15..9ff8cd6704d0 100644 --- a/crates/polars-core/src/series/implementations/binary_offset.rs +++ b/crates/polars-core/src/series/implementations/binary_offset.rs @@ -13,7 +13,7 @@ impl private::PrivateSeries for SeriesWrap { Cow::Borrowed(self.0.ref_field()) } fn _dtype(&self) -> &DataType { - self.0.ref_field().data_type() + self.0.ref_field().dtype() } fn _get_flags(&self) -> MetadataFlags { self.0.get_flags() @@ -33,12 +33,16 @@ impl private::PrivateSeries for SeriesWrap { (&self.0).into_total_ord_inner() } - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash(&self, random_state: PlRandomState, buf: &mut Vec) -> PolarsResult<()> { self.0.vec_hash(random_state, buf)?; Ok(()) } - fn vec_hash_combine(&self, build_hasher: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { + fn vec_hash_combine( + &self, + build_hasher: PlRandomState, + hashes: &mut [u64], + ) -> PolarsResult<()> { self.0.vec_hash_combine(build_hasher, hashes)?; Ok(()) } @@ -58,14 +62,14 @@ impl private::PrivateSeries for SeriesWrap { } impl SeriesTrait for SeriesWrap { - fn rename(&mut self, name: &str) { + fn rename(&mut self, name: PlSmallStr) { self.0.rename(name); } fn chunk_lengths(&self) -> ChunkLenIter { self.0.chunk_lengths() } - fn name(&self) -> &str { + fn name(&self) -> &PlSmallStr { self.0.name() } @@ -90,13 +94,13 @@ impl SeriesTrait for SeriesWrap { fn append(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), append); // todo! 
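`ChunkedArray::append`/`extend` are now fallible, so these impls propagate the error with `?` instead of discarding it. Caller-side code is unchanged, since `Series::append` already returned a `PolarsResult`; a short sketch:

```rust
use polars_core::prelude::*;

fn main() -> PolarsResult<()> {
    let mut a = Series::new("a".into(), [1i32, 2]);
    let b = Series::new("a".into(), [3i32, 4]);
    // Dtype mismatches (and now the underlying append) surface as errors.
    a.append(&b)?;
    assert_eq!(a.len(), 4);
    Ok(())
}
```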
add object - self.0.append(other.as_ref().as_ref()); + self.0.append(other.as_ref().as_ref())?; Ok(()) } fn extend(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), extend); - self.0.extend(other.as_ref().as_ref()); + self.0.extend(other.as_ref().as_ref())?; Ok(()) } @@ -138,8 +142,8 @@ impl SeriesTrait for SeriesWrap { ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() } - fn cast(&self, data_type: &DataType, options: CastOptions) -> PolarsResult { - self.0.cast_with_options(data_type, options) + fn cast(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, options) } fn get(&self, index: usize) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/boolean.rs b/crates/polars-core/src/series/implementations/boolean.rs index 275e2d1a26a7..aae8a5837af8 100644 --- a/crates/polars-core/src/series/implementations/boolean.rs +++ b/crates/polars-core/src/series/implementations/boolean.rs @@ -12,7 +12,7 @@ impl private::PrivateSeries for SeriesWrap { Cow::Borrowed(self.0.ref_field()) } fn _dtype(&self) -> &DataType { - self.0.ref_field().data_type() + self.0.ref_field().dtype() } fn _get_flags(&self) -> MetadataFlags { self.0.get_flags() @@ -20,9 +20,6 @@ impl private::PrivateSeries for SeriesWrap { fn _set_flags(&mut self, flags: MetadataFlags) { self.0.set_flags(flags) } - fn explode_by_offsets(&self, offsets: &[i64]) -> Series { - self.0.explode_by_offsets(offsets) - } unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { self.0.equal_element(idx_self, idx_other, other) @@ -39,12 +36,16 @@ impl private::PrivateSeries for SeriesWrap { (&self.0).into_total_ord_inner() } - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash(&self, random_state: PlRandomState, buf: &mut Vec) -> PolarsResult<()> { self.0.vec_hash(random_state, buf)?; Ok(()) } - fn vec_hash_combine(&self, build_hasher: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { + fn vec_hash_combine( + &self, + build_hasher: PlRandomState, + hashes: &mut [u64], + ) -> PolarsResult<()> { self.0.vec_hash_combine(build_hasher, hashes)?; Ok(()) } @@ -120,14 +121,14 @@ impl SeriesTrait for SeriesWrap { Ok((&self.0).bitor(other).into_series()) } - fn rename(&mut self, name: &str) { + fn rename(&mut self, name: PlSmallStr) { self.0.rename(name); } fn chunk_lengths(&self) -> ChunkLenIter { self.0.chunk_lengths() } - fn name(&self) -> &str { + fn name(&self) -> &PlSmallStr { self.0.name() } @@ -151,13 +152,13 @@ impl SeriesTrait for SeriesWrap { fn append(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), append); - self.0.append(other.as_ref().as_ref()); + self.0.append(other.as_ref().as_ref())?; Ok(()) } fn extend(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), extend); - self.0.extend(other.as_ref().as_ref()); + self.0.extend(other.as_ref().as_ref())?; Ok(()) } @@ -165,6 +166,10 @@ impl SeriesTrait for SeriesWrap { ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series()) } + fn _sum_as_f64(&self) -> f64 { + self.0.sum().unwrap() as f64 + } + fn mean(&self) -> Option { self.0.mean() } @@ -197,8 +202,8 @@ impl SeriesTrait for SeriesWrap { ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() } - fn cast(&self, data_type: &DataType, options: CastOptions) -> PolarsResult { - self.0.cast_with_options(data_type, 
options) + fn cast(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, options) } fn get(&self, index: usize) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/categorical.rs b/crates/polars-core/src/series/implementations/categorical.rs index 5db3344726e7..497ff5267d88 100644 --- a/crates/polars-core/src/series/implementations/categorical.rs +++ b/crates/polars-core/src/series/implementations/categorical.rs @@ -62,14 +62,6 @@ impl private::PrivateSeries for SeriesWrap { self.0.set_flags(flags) } - fn explode_by_offsets(&self, offsets: &[i64]) -> Series { - // TODO! explode by offset should return concrete type - self.with_state(true, |cats| { - cats.explode_by_offsets(offsets).u32().unwrap().clone() - }) - .into_series() - } - unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { self.0.physical().equal_element(idx_self, idx_other, other) } @@ -88,12 +80,16 @@ impl private::PrivateSeries for SeriesWrap { } } - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash(&self, random_state: PlRandomState, buf: &mut Vec) -> PolarsResult<()> { self.0.physical().vec_hash(random_state, buf)?; Ok(()) } - fn vec_hash_combine(&self, build_hasher: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { + fn vec_hash_combine( + &self, + build_hasher: PlRandomState, + hashes: &mut [u64], + ) -> PolarsResult<()> { self.0.physical().vec_hash_combine(build_hasher, hashes)?; Ok(()) } @@ -129,14 +125,14 @@ impl private::PrivateSeries for SeriesWrap { } impl SeriesTrait for SeriesWrap { - fn rename(&mut self, name: &str) { + fn rename(&mut self, name: PlSmallStr) { self.0.physical_mut().rename(name); } fn chunk_lengths(&self) -> ChunkLenIter { self.0.physical().chunk_lengths() } - fn name(&self) -> &str { + fn name(&self) -> &PlSmallStr { self.0.physical().name() } @@ -176,7 +172,7 @@ impl SeriesTrait for SeriesWrap { (RevMapping::Global(_, _, idl), RevMapping::Global(_, _, idr)) if idl == idr => { let mut rev_map_merger = GlobalRevMapMerger::new(rev_map_self.clone()); rev_map_merger.merge_map(rev_map_other)?; - self.0.physical_mut().extend(other_ca.physical()); + self.0.physical_mut().extend(other_ca.physical())?; // SAFETY: rev_maps are merged unsafe { self.0.set_rev_map(rev_map_merger.finish(), false) }; Ok(()) @@ -223,8 +219,8 @@ impl SeriesTrait for SeriesWrap { .into_series() } - fn cast(&self, data_type: &DataType, options: CastOptions) -> PolarsResult { - self.0.cast_with_options(data_type, options) + fn cast(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, options) } fn get(&self, index: usize) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/date.rs b/crates/polars-core/src/series/implementations/date.rs index 0585952da0ce..834449e73992 100644 --- a/crates/polars-core/src/series/implementations/date.rs +++ b/crates/polars-core/src/series/implementations/date.rs @@ -39,10 +39,6 @@ impl private::PrivateSeries for SeriesWrap { self.0.set_flags(flags) } - fn explode_by_offsets(&self, offsets: &[i64]) -> Series { - self.0.explode_by_offsets(offsets).into_date().into_series() - } - #[cfg(feature = "zip_with")] fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { let other = other.to_physical_repr().into_owned(); @@ -51,12 +47,16 @@ impl private::PrivateSeries for SeriesWrap { .map(|ca| ca.into_date().into_series()) } - fn vec_hash(&self, 
random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash(&self, random_state: PlRandomState, buf: &mut Vec) -> PolarsResult<()> { self.0.vec_hash(random_state, buf)?; Ok(()) } - fn vec_hash_combine(&self, build_hasher: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { + fn vec_hash_combine( + &self, + build_hasher: PlRandomState, + hashes: &mut [u64], + ) -> PolarsResult<()> { self.0.vec_hash_combine(build_hasher, hashes)?; Ok(()) } @@ -140,14 +140,14 @@ impl private::PrivateSeries for SeriesWrap { } impl SeriesTrait for SeriesWrap { - fn rename(&mut self, name: &str) { + fn rename(&mut self, name: PlSmallStr) { self.0.rename(name); } fn chunk_lengths(&self) -> ChunkLenIter { self.0.chunk_lengths() } - fn name(&self) -> &str { + fn name(&self) -> &PlSmallStr { self.0.name() } @@ -170,6 +170,10 @@ impl SeriesTrait for SeriesWrap { (a.into_date().into_series(), b.into_date().into_series()) } + fn _sum_as_f64(&self) -> f64 { + self.0._sum_as_f64() + } + fn mean(&self) -> Option { self.0.mean() } @@ -185,7 +189,7 @@ impl SeriesTrait for SeriesWrap { // ref Cow // ref SeriesTrait // ref ChunkedArray - self.0.append(other.as_ref().as_ref().as_ref()); + self.0.append(other.as_ref().as_ref().as_ref())?; Ok(()) } fn extend(&mut self, other: &Series) -> PolarsResult<()> { @@ -195,7 +199,7 @@ impl SeriesTrait for SeriesWrap { // ref SeriesTrait // ref ChunkedArray let other = other.to_physical_repr(); - self.0.extend(other.as_ref().as_ref().as_ref()); + self.0.extend(other.as_ref().as_ref().as_ref())?; Ok(()) } @@ -234,8 +238,8 @@ impl SeriesTrait for SeriesWrap { .into_series() } - fn cast(&self, data_type: &DataType, cast_options: CastOptions) -> PolarsResult { - match data_type { + fn cast(&self, dtype: &DataType, cast_options: CastOptions) -> PolarsResult { + match dtype { DataType::String => Ok(self .0 .clone() @@ -246,13 +250,11 @@ impl SeriesTrait for SeriesWrap { .into_series()), #[cfg(feature = "dtype-datetime")] DataType::Datetime(_, _) => { - let mut out = self - .0 - .cast_with_options(data_type, CastOptions::NonStrict)?; + let mut out = self.0.cast_with_options(dtype, CastOptions::NonStrict)?; out.set_sorted_flag(self.0.is_sorted_flag()); Ok(out) }, - _ => self.0.cast_with_options(data_type, cast_options), + _ => self.0.cast_with_options(dtype, cast_options), } } diff --git a/crates/polars-core/src/series/implementations/datetime.rs b/crates/polars-core/src/series/implementations/datetime.rs index f5b704a3590a..a6a5f111d541 100644 --- a/crates/polars-core/src/series/implementations/datetime.rs +++ b/crates/polars-core/src/series/implementations/datetime.rs @@ -32,13 +32,6 @@ impl private::PrivateSeries for SeriesWrap { self.0.set_flags(flags) } - fn explode_by_offsets(&self, offsets: &[i64]) -> Series { - self.0 - .explode_by_offsets(offsets) - .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) - .into_series() - } - #[cfg(feature = "zip_with")] fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { let other = other.to_physical_repr().into_owned(); @@ -48,12 +41,16 @@ impl private::PrivateSeries for SeriesWrap { }) } - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash(&self, random_state: PlRandomState, buf: &mut Vec) -> PolarsResult<()> { self.0.vec_hash(random_state, buf)?; Ok(()) } - fn vec_hash_combine(&self, build_hasher: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { + fn vec_hash_combine( + &self, + build_hasher: PlRandomState, + hashes: &mut [u64], + ) -> 
PolarsResult<()> { self.0.vec_hash_combine(build_hasher, hashes)?; Ok(()) } @@ -141,14 +138,14 @@ impl private::PrivateSeries for SeriesWrap { } impl SeriesTrait for SeriesWrap { - fn rename(&mut self, name: &str) { + fn rename(&mut self, name: PlSmallStr) { self.0.rename(name); } fn chunk_lengths(&self) -> ChunkLenIter { self.0.chunk_lengths() } - fn name(&self) -> &str { + fn name(&self) -> &PlSmallStr { self.0.name() } @@ -179,6 +176,10 @@ impl SeriesTrait for SeriesWrap { ) } + fn _sum_as_f64(&self) -> f64 { + self.0._sum_as_f64() + } + fn mean(&self) -> Option { self.0.mean() } @@ -190,14 +191,14 @@ impl SeriesTrait for SeriesWrap { fn append(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), append); let other = other.to_physical_repr(); - self.0.append(other.as_ref().as_ref().as_ref()); + self.0.append(other.as_ref().as_ref().as_ref())?; Ok(()) } fn extend(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), extend); let other = other.to_physical_repr(); - self.0.extend(other.as_ref().as_ref().as_ref()); + self.0.extend(other.as_ref().as_ref().as_ref())?; Ok(()) } @@ -252,8 +253,8 @@ impl SeriesTrait for SeriesWrap { .into_series() } - fn cast(&self, data_type: &DataType, cast_options: CastOptions) -> PolarsResult { - match (data_type, self.0.time_unit()) { + fn cast(&self, dtype: &DataType, cast_options: CastOptions) -> PolarsResult { + match (dtype, self.0.time_unit()) { (DataType::String, TimeUnit::Milliseconds) => { Ok(self.0.to_string("%F %T%.3f")?.into_series()) }, @@ -263,7 +264,7 @@ impl SeriesTrait for SeriesWrap { (DataType::String, TimeUnit::Nanoseconds) => { Ok(self.0.to_string("%F %T%.9f")?.into_series()) }, - _ => self.0.cast_with_options(data_type, cast_options), + _ => self.0.cast_with_options(dtype, cast_options), } } diff --git a/crates/polars-core/src/series/implementations/decimal.rs b/crates/polars-core/src/series/implementations/decimal.rs index 144daabf2244..30125ccc15b6 100644 --- a/crates/polars-core/src/series/implementations/decimal.rs +++ b/crates/polars-core/src/series/implementations/decimal.rs @@ -54,20 +54,24 @@ impl SeriesWrap { let arr = ca.downcast_iter().next().unwrap(); // SAFETY: dtype is passed correctly let s = unsafe { - Series::from_chunks_and_dtype_unchecked("", vec![arr.values().clone()], dtype) + Series::from_chunks_and_dtype_unchecked( + PlSmallStr::EMPTY, + vec![arr.values().clone()], + dtype, + ) }; let new_values = s.array_ref(0).clone(); - let data_type = + let dtype = ListArray::::default_datatype(dtype.to_arrow(CompatLevel::newest())); let new_arr = ListArray::::new( - data_type, + dtype, arr.offsets().clone(), new_values, arr.validity().cloned(), ); unsafe { ListChunked::from_chunks_and_dtype_unchecked( - agg_s.name(), + agg_s.name().clone(), vec![Box::new(new_arr)], DataType::List(Box::new(self.dtype().clone())), ) @@ -126,12 +130,16 @@ impl private::PrivateSeries for SeriesWrap { (&self.0).into_total_ord_inner() } - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash(&self, random_state: PlRandomState, buf: &mut Vec) -> PolarsResult<()> { self.0.vec_hash(random_state, buf)?; Ok(()) } - fn vec_hash_combine(&self, build_hasher: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { + fn vec_hash_combine( + &self, + build_hasher: PlRandomState, + hashes: &mut [u64], + ) -> PolarsResult<()> { self.0.vec_hash_combine(build_hasher, hashes)?; Ok(()) } @@ -176,21 +184,10 @@ impl private::PrivateSeries 
for SeriesWrap { fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { self.0.group_tuples(multithreaded, sorted) } - - fn explode_by_offsets(&self, offsets: &[i64]) -> Series { - self.0 - .explode_by_offsets(offsets) - .decimal() - .unwrap() - .as_ref() - .clone() - .into_decimal_unchecked(self.0.precision(), self.0.scale()) - .into_series() - } } impl SeriesTrait for SeriesWrap { - fn rename(&mut self, name: &str) { + fn rename(&mut self, name: PlSmallStr) { self.0.rename(name) } @@ -198,7 +195,7 @@ impl SeriesTrait for SeriesWrap { self.0.chunk_lengths() } - fn name(&self) -> &str { + fn name(&self) -> &PlSmallStr { self.0.name() } @@ -227,14 +224,14 @@ impl SeriesTrait for SeriesWrap { fn append(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), append); let other = other.decimal()?; - self.0.append(&other.0); + self.0.append(&other.0)?; Ok(()) } fn extend(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), extend); let other = other.decimal()?; - self.0.extend(&other.0); + self.0.extend(&other.0)?; Ok(()) } @@ -293,8 +290,8 @@ impl SeriesTrait for SeriesWrap { .into_series() } - fn cast(&self, data_type: &DataType, cast_options: CastOptions) -> PolarsResult { - self.0.cast_with_options(data_type, cast_options) + fn cast(&self, dtype: &DataType, cast_options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, cast_options) } fn get(&self, index: usize) -> PolarsResult { @@ -385,6 +382,10 @@ impl SeriesTrait for SeriesWrap { })) } + fn _sum_as_f64(&self) -> f64 { + self.0._sum_as_f64() / self.scale_factor() as f64 + } + fn mean(&self) -> Option { self.0.mean().map(|v| v / self.scale_factor() as f64) } diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs index 2e4565015621..73d2e4f730fb 100644 --- a/crates/polars-core/src/series/implementations/duration.rs +++ b/crates/polars-core/src/series/implementations/duration.rs @@ -29,13 +29,6 @@ impl private::PrivateSeries for SeriesWrap { self.0.dtype() } - fn explode_by_offsets(&self, offsets: &[i64]) -> Series { - self.0 - .explode_by_offsets(offsets) - .into_duration(self.0.time_unit()) - .into_series() - } - fn _set_flags(&mut self, flags: MetadataFlags) { self.0.deref_mut().set_flags(flags) } @@ -55,12 +48,16 @@ impl private::PrivateSeries for SeriesWrap { .map(|ca| ca.into_duration(self.0.time_unit()).into_series()) } - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash(&self, random_state: PlRandomState, buf: &mut Vec) -> PolarsResult<()> { self.0.vec_hash(random_state, buf)?; Ok(()) } - fn vec_hash_combine(&self, build_hasher: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { + fn vec_hash_combine( + &self, + build_hasher: PlRandomState, + hashes: &mut [u64], + ) -> PolarsResult<()> { self.0.vec_hash_combine(build_hasher, hashes)?; Ok(()) } @@ -255,14 +252,14 @@ impl private::PrivateSeries for SeriesWrap { } impl SeriesTrait for SeriesWrap { - fn rename(&mut self, name: &str) { + fn rename(&mut self, name: PlSmallStr) { self.0.rename(name); } fn chunk_lengths(&self) -> ChunkLenIter { self.0.chunk_lengths() } - fn name(&self) -> &str { + fn name(&self) -> &PlSmallStr { self.0.name() } @@ -291,6 +288,10 @@ impl SeriesTrait for SeriesWrap { (a, b) } + fn _sum_as_f64(&self) -> f64 { + self.0._sum_as_f64() + } + fn mean(&self) -> Option { self.0.mean() } @@ -310,14 +311,14 @@ impl 
SeriesTrait for SeriesWrap { fn append(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), append); let other = other.to_physical_repr().into_owned(); - self.0.append(other.as_ref().as_ref()); + self.0.append(other.as_ref().as_ref())?; Ok(()) } fn extend(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), extend); let other = other.to_physical_repr(); - self.0.extend(other.as_ref().as_ref().as_ref()); + self.0.extend(other.as_ref().as_ref().as_ref())?; Ok(()) } @@ -375,8 +376,8 @@ impl SeriesTrait for SeriesWrap { .into_series() } - fn cast(&self, data_type: &DataType, cast_options: CastOptions) -> PolarsResult { - self.0.cast_with_options(data_type, cast_options) + fn cast(&self, dtype: &DataType, cast_options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, cast_options) } fn get(&self, index: usize) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/floats.rs b/crates/polars-core/src/series/implementations/floats.rs index 3f6ccbe1c3f9..cc52d73cdc60 100644 --- a/crates/polars-core/src/series/implementations/floats.rs +++ b/crates/polars-core/src/series/implementations/floats.rs @@ -14,7 +14,7 @@ macro_rules! impl_dyn_series { Cow::Borrowed(self.0.ref_field()) } fn _dtype(&self) -> &DataType { - self.0.ref_field().data_type() + self.0.ref_field().dtype() } fn _set_flags(&mut self, flags: MetadataFlags) { @@ -23,10 +23,6 @@ macro_rules! impl_dyn_series { fn _get_flags(&self) -> MetadataFlags { self.0.get_flags() } - fn explode_by_offsets(&self, offsets: &[i64]) -> Series { - self.0.explode_by_offsets(offsets) - } - unsafe fn equal_element( &self, idx_self: usize, @@ -52,14 +48,18 @@ macro_rules! impl_dyn_series { (&self.0).into_total_ord_inner() } - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash( + &self, + random_state: PlRandomState, + buf: &mut Vec, + ) -> PolarsResult<()> { self.0.vec_hash(random_state, buf)?; Ok(()) } fn vec_hash_combine( &self, - build_hasher: RandomState, + build_hasher: PlRandomState, hashes: &mut [u64], ) -> PolarsResult<()> { self.0.vec_hash_combine(build_hasher, hashes)?; @@ -169,14 +169,14 @@ macro_rules! impl_dyn_series { self.metadata_dyn() } - fn rename(&mut self, name: &str) { + fn rename(&mut self, name: PlSmallStr) { self.0.rename(name); } fn chunk_lengths(&self) -> ChunkLenIter { self.0.chunk_lengths() } - fn name(&self) -> &str { + fn name(&self) -> &PlSmallStr { self.0.name() } @@ -201,13 +201,13 @@ macro_rules! impl_dyn_series { fn append(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), append); - self.0.append(other.as_ref().as_ref()); + self.0.append(other.as_ref().as_ref())?; Ok(()) } fn extend(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), extend); - self.0.extend(other.as_ref().as_ref()); + self.0.extend(other.as_ref().as_ref())?; Ok(()) } @@ -215,6 +215,10 @@ macro_rules! impl_dyn_series { ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series()) } + fn _sum_as_f64(&self) -> f64 { + self.0._sum_as_f64() + } + fn mean(&self) -> Option { self.0.mean() } @@ -259,12 +263,8 @@ macro_rules! 
impl_dyn_series { ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() } - fn cast( - &self, - data_type: &DataType, - cast_options: CastOptions, - ) -> PolarsResult { - self.0.cast_with_options(data_type, cast_options) + fn cast(&self, dtype: &DataType, cast_options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, cast_options) } fn get(&self, index: usize) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/list.rs b/crates/polars-core/src/series/implementations/list.rs index c66fd2ce8492..5e5b4a95d5e2 100644 --- a/crates/polars-core/src/series/implementations/list.rs +++ b/crates/polars-core/src/series/implementations/list.rs @@ -12,7 +12,7 @@ impl private::PrivateSeries for SeriesWrap { Cow::Borrowed(self.0.ref_field()) } fn _dtype(&self) -> &DataType { - self.0.ref_field().data_type() + self.0.ref_field().dtype() } fn _get_flags(&self) -> MetadataFlags { self.0.get_flags() @@ -21,10 +21,6 @@ impl private::PrivateSeries for SeriesWrap { self.0.set_flags(flags) } - fn explode_by_offsets(&self, offsets: &[i64]) -> Series { - self.0.explode_by_offsets(offsets) - } - unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { self.0.equal_element(idx_self, idx_other, other) } @@ -68,14 +64,14 @@ impl private::PrivateSeries for SeriesWrap { } impl SeriesTrait for SeriesWrap { - fn rename(&mut self, name: &str) { + fn rename(&mut self, name: PlSmallStr) { self.0.rename(name); } fn chunk_lengths(&self) -> ChunkLenIter { self.0.chunk_lengths() } - fn name(&self) -> &str { + fn name(&self) -> &PlSmallStr { self.0.name() } @@ -148,8 +144,8 @@ impl SeriesTrait for SeriesWrap { ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() } - fn cast(&self, data_type: &DataType, cast_options: CastOptions) -> PolarsResult { - self.0.cast_with_options(data_type, cast_options) + fn cast(&self, dtype: &DataType, cast_options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, cast_options) } fn get(&self, index: usize) -> PolarsResult { @@ -209,13 +205,13 @@ impl SeriesTrait for SeriesWrap { } // this can be called in aggregation, so this fast path can be worth a lot if self.len() == 1 { - return Ok(IdxCa::new_vec(self.name(), vec![0 as IdxSize])); + return Ok(IdxCa::new_vec(self.name().clone(), vec![0 as IdxSize])); } let main_thread = POOL.current_thread_index().is_none(); // arg_unique requires a stable order let groups = self.group_tuples(main_thread, true)?; let first = groups.take_group_firsts(); - Ok(IdxCa::from_vec(self.name(), first)) + Ok(IdxCa::from_vec(self.name().clone(), first)) } fn is_null(&self) -> BooleanChunked { diff --git a/crates/polars-core/src/series/implementations/mod.rs b/crates/polars-core/src/series/implementations/mod.rs index d50f84c2d532..3e4e41395b0b 100644 --- a/crates/polars-core/src/series/implementations/mod.rs +++ b/crates/polars-core/src/series/implementations/mod.rs @@ -20,7 +20,7 @@ pub(crate) mod null; mod object; mod string; #[cfg(feature = "dtype-struct")] -mod struct__; +mod struct_; #[cfg(feature = "dtype-time")] mod time; @@ -29,15 +29,12 @@ use std::borrow::Cow; use std::ops::{BitAnd, BitOr, BitXor}; use std::sync::RwLockReadGuard; -use ahash::RandomState; - use super::*; use crate::chunked_array::comparison::*; use crate::chunked_array::metadata::MetadataTrait; use crate::chunked_array::ops::compare_inner::{ IntoTotalEqInner, IntoTotalOrdInner, TotalEqInner, TotalOrdInner, }; -use crate::chunked_array::ops::explode::ExplodeByOffsets; 
use crate::chunked_array::AsSinglePtr; // Utility wrapper struct @@ -81,7 +78,7 @@ macro_rules! impl_dyn_series { } fn _dtype(&self) -> &DataType { - self.0.ref_field().data_type() + self.0.ref_field().dtype() } fn _get_flags(&self) -> MetadataFlags { @@ -92,10 +89,6 @@ macro_rules! impl_dyn_series { self.0.set_flags(flags) } - fn explode_by_offsets(&self, offsets: &[i64]) -> Series { - self.0.explode_by_offsets(offsets) - } - unsafe fn equal_element( &self, idx_self: usize, @@ -121,14 +114,18 @@ macro_rules! impl_dyn_series { (&self.0).into_total_ord_inner() } - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash( + &self, + random_state: PlRandomState, + buf: &mut Vec, + ) -> PolarsResult<()> { self.0.vec_hash(random_state, buf)?; Ok(()) } fn vec_hash_combine( &self, - build_hasher: RandomState, + build_hasher: PlRandomState, hashes: &mut [u64], ) -> PolarsResult<()> { self.0.vec_hash_combine(build_hasher, hashes)?; @@ -275,14 +272,14 @@ macro_rules! impl_dyn_series { Ok(self.0.bitxor(&other).into_series()) } - fn rename(&mut self, name: &str) { + fn rename(&mut self, name: PlSmallStr) { self.0.rename(name); } fn chunk_lengths(&self) -> ChunkLenIter { self.0.chunk_lengths() } - fn name(&self) -> &str { + fn name(&self) -> &PlSmallStr { self.0.name() } @@ -307,13 +304,13 @@ macro_rules! impl_dyn_series { fn append(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), append); - self.0.append(other.as_ref().as_ref()); + self.0.append(other.as_ref().as_ref())?; Ok(()) } fn extend(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), extend); - self.0.extend(other.as_ref().as_ref()); + self.0.extend(other.as_ref().as_ref())?; Ok(()) } @@ -321,6 +318,10 @@ macro_rules! impl_dyn_series { ChunkFilter::filter(&self.0, filter).map(|ca| ca.into_series()) } + fn _sum_as_f64(&self) -> f64 { + self.0._sum_as_f64() + } + fn mean(&self) -> Option { self.0.mean() } @@ -365,8 +366,8 @@ macro_rules! 
impl_dyn_series { ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() } - fn cast(&self, data_type: &DataType, options: CastOptions) -> PolarsResult { - self.0.cast_with_options(data_type, options) + fn cast(&self, dtype: &DataType, options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, options) } fn get(&self, index: usize) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/null.rs b/crates/polars-core/src/series/implementations/null.rs index e935cb03ced8..34655a6ac61e 100644 --- a/crates/polars-core/src/series/implementations/null.rs +++ b/crates/polars-core/src/series/implementations/null.rs @@ -3,20 +3,19 @@ use std::any::Any; use polars_error::constants::LENGTH_LIMIT_MSG; use crate::prelude::compare_inner::{IntoTotalEqInner, TotalEqInner}; -use crate::prelude::explode::ExplodeByOffsets; use crate::prelude::*; use crate::series::private::{PrivateSeries, PrivateSeriesNumeric}; use crate::series::*; impl Series { - pub fn new_null(name: &str, len: usize) -> Series { - NullChunked::new(Arc::from(name), len).into_series() + pub fn new_null(name: PlSmallStr, len: usize) -> Series { + NullChunked::new(name, len).into_series() } } #[derive(Clone)] pub struct NullChunked { - pub(crate) name: Arc, + pub(crate) name: PlSmallStr, length: IdxSize, // we still need chunks as many series consumers expect // chunks to be there @@ -24,7 +23,7 @@ pub struct NullChunked { } impl NullChunked { - pub(crate) fn new(name: Arc, len: usize) -> Self { + pub(crate) fn new(name: PlSmallStr, len: usize) -> Self { Self { name, length: len as IdxSize, @@ -38,7 +37,7 @@ impl NullChunked { impl PrivateSeriesNumeric for NullChunked { fn bit_repr(&self) -> Option { Some(BitRepr::Small(UInt32Chunked::full_null( - self.name.as_ref(), + self.name.clone(), self.len(), ))) } @@ -56,7 +55,7 @@ impl PrivateSeries for NullChunked { self.length = IdxSize::try_from(inner(&self.chunks)).expect(LENGTH_LIMIT_MSG); } fn _field(&self) -> Cow { - Cow::Owned(Field::new(self.name(), DataType::Null)) + Cow::Owned(Field::new(self.name().clone(), DataType::Null)) } #[allow(unused)] @@ -78,12 +77,8 @@ impl PrivateSeries for NullChunked { }, }; - Ok(Self::new(self.name().into(), len).into_series()) + Ok(Self::new(self.name().clone(), len).into_series()) } - fn explode_by_offsets(&self, offsets: &[i64]) -> Series { - ExplodeByOffsets::explode_by_offsets(self, offsets) - } - fn subtract(&self, _rhs: &Series) -> PolarsResult { null_arithmetic(self, _rhs, "subtract") } @@ -122,12 +117,16 @@ impl PrivateSeries for NullChunked { MetadataFlags::empty() } - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash(&self, random_state: PlRandomState, buf: &mut Vec) -> PolarsResult<()> { VecHash::vec_hash(self, random_state, buf)?; Ok(()) } - fn vec_hash_combine(&self, build_hasher: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { + fn vec_hash_combine( + &self, + build_hasher: PlRandomState, + hashes: &mut [u64], + ) -> PolarsResult<()> { VecHash::vec_hash_combine(self, build_hasher, hashes)?; Ok(()) } @@ -144,16 +143,16 @@ fn null_arithmetic(lhs: &NullChunked, rhs: &Series, op: &str) -> PolarsResult len_l, _ => polars_bail!(ComputeError: "Cannot {:?} two series of different lengths.", op), }; - Ok(NullChunked::new(lhs.name().into(), output_len).into_series()) + Ok(NullChunked::new(lhs.name().clone(), output_len).into_series()) } impl SeriesTrait for NullChunked { - fn name(&self) -> &str { - self.name.as_ref() + fn name(&self) -> 
&PlSmallStr { + &self.name } - fn rename(&mut self, name: &str) { - self.name = Arc::from(name) + fn rename(&mut self, name: PlSmallStr) { + self.name = name } fn chunks(&self) -> &Vec { @@ -199,8 +198,8 @@ impl SeriesTrait for NullChunked { NullChunked::new(self.name.clone(), 0).into_series() } - fn cast(&self, data_type: &DataType, _cast_options: CastOptions) -> PolarsResult { - Ok(Series::full_null(self.name.as_ref(), self.len(), data_type)) + fn cast(&self, dtype: &DataType, _cast_options: CastOptions) -> PolarsResult { + Ok(Series::full_null(self.name.clone(), self.len(), dtype)) } fn null_count(&self) -> usize { @@ -265,11 +264,11 @@ impl SeriesTrait for NullChunked { } fn is_null(&self) -> BooleanChunked { - BooleanChunked::full(self.name(), true, self.len()) + BooleanChunked::full(self.name().clone(), true, self.len()) } fn is_not_null(&self) -> BooleanChunked { - BooleanChunked::full(self.name(), false, self.len()) + BooleanChunked::full(self.name().clone(), false, self.len()) } fn reverse(&self) -> Series { diff --git a/crates/polars-core/src/series/implementations/object.rs b/crates/polars-core/src/series/implementations/object.rs index 8ba9eec2e8df..5ef70ed8c99c 100644 --- a/crates/polars-core/src/series/implementations/object.rs +++ b/crates/polars-core/src/series/implementations/object.rs @@ -1,8 +1,6 @@ use std::any::Any; use std::borrow::Cow; -use ahash::RandomState; - use super::{BitRepr, MetadataFlags}; use crate::chunked_array::cast::CastOptions; use crate::chunked_array::object::PolarsObjectSafe; @@ -23,7 +21,7 @@ where { fn get_list_builder( &self, - _name: &str, + _name: PlSmallStr, _values_capacity: usize, _list_capacity: usize, ) -> Box { @@ -56,12 +54,16 @@ where (&self.0).into_total_eq_inner() } - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash(&self, random_state: PlRandomState, buf: &mut Vec) -> PolarsResult<()> { self.0.vec_hash(random_state, buf)?; Ok(()) } - fn vec_hash_combine(&self, build_hasher: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { + fn vec_hash_combine( + &self, + build_hasher: PlRandomState, + hashes: &mut [u64], + ) -> PolarsResult<()> { self.0.vec_hash_combine(build_hasher, hashes)?; Ok(()) } @@ -81,7 +83,7 @@ impl SeriesTrait for SeriesWrap> where T: PolarsObject, { - fn rename(&mut self, name: &str) { + fn rename(&mut self, name: PlSmallStr) { ObjectChunked::rename(&mut self.0, name) } @@ -89,7 +91,7 @@ where ObjectChunked::chunk_lengths(&self.0) } - fn name(&self) -> &str { + fn name(&self) -> &PlSmallStr { ObjectChunked::name(&self.0) } @@ -117,7 +119,7 @@ where if self.dtype() != other.dtype() { polars_bail!(append); } - ObjectChunked::append(&mut self.0, other.as_ref().as_ref()); + ObjectChunked::append(&mut self.0, other.as_ref().as_ref())?; Ok(()) } @@ -158,8 +160,8 @@ where ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() } - fn cast(&self, data_type: &DataType, _cast_options: CastOptions) -> PolarsResult { - if matches!(data_type, DataType::Object(_, None)) { + fn cast(&self, dtype: &DataType, _cast_options: CastOptions) -> PolarsResult { + if matches!(dtype, DataType::Object(_, None)) { Ok(self.0.clone().into_series()) } else { Err(PolarsError::ComputeError( @@ -244,7 +246,7 @@ mod test { } } - let ca = ObjectChunked::new_from_vec("a", vec![0i32, 1, 2]); + let ca = ObjectChunked::new_from_vec("a".into(), vec![0i32, 1, 2]); let s = ca.into_series(); let ca = s.as_any().downcast_ref::>().unwrap(); diff --git 
a/crates/polars-core/src/series/implementations/string.rs b/crates/polars-core/src/series/implementations/string.rs index 07e2ec5b14ad..c8d85825e84b 100644 --- a/crates/polars-core/src/series/implementations/string.rs +++ b/crates/polars-core/src/series/implementations/string.rs @@ -12,7 +12,7 @@ impl private::PrivateSeries for SeriesWrap { Cow::Borrowed(self.0.ref_field()) } fn _dtype(&self) -> &DataType { - self.0.ref_field().data_type() + self.0.ref_field().dtype() } fn _set_flags(&mut self, flags: MetadataFlags) { @@ -21,10 +21,6 @@ impl private::PrivateSeries for SeriesWrap { fn _get_flags(&self) -> MetadataFlags { self.0.get_flags() } - fn explode_by_offsets(&self, offsets: &[i64]) -> Series { - self.0.explode_by_offsets(offsets) - } - unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { self.0.equal_element(idx_self, idx_other, other) } @@ -40,12 +36,16 @@ impl private::PrivateSeries for SeriesWrap { (&self.0).into_total_ord_inner() } - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash(&self, random_state: PlRandomState, buf: &mut Vec) -> PolarsResult<()> { self.0.vec_hash(random_state, buf)?; Ok(()) } - fn vec_hash_combine(&self, build_hasher: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { + fn vec_hash_combine( + &self, + build_hasher: PlRandomState, + hashes: &mut [u64], + ) -> PolarsResult<()> { self.0.vec_hash_combine(build_hasher, hashes)?; Ok(()) } @@ -95,14 +95,14 @@ impl private::PrivateSeries for SeriesWrap { } impl SeriesTrait for SeriesWrap { - fn rename(&mut self, name: &str) { + fn rename(&mut self, name: PlSmallStr) { self.0.rename(name); } fn chunk_lengths(&self) -> ChunkLenIter { self.0.chunk_lengths() } - fn name(&self) -> &str { + fn name(&self) -> &PlSmallStr { self.0.name() } @@ -130,7 +130,7 @@ impl SeriesTrait for SeriesWrap { SchemaMismatch: "cannot extend Series: data types don't match", ); // todo! add object - self.0.append(other.as_ref().as_ref()); + self.0.append(other.as_ref().as_ref())?; Ok(()) } @@ -139,7 +139,7 @@ impl SeriesTrait for SeriesWrap { self.0.dtype() == other.dtype(), SchemaMismatch: "cannot extend Series: data types don't match", ); - self.0.extend(other.as_ref().as_ref()); + self.0.extend(other.as_ref().as_ref())?; Ok(()) } @@ -175,8 +175,8 @@ impl SeriesTrait for SeriesWrap { ChunkExpandAtIndex::new_from_index(&self.0, index, length).into_series() } - fn cast(&self, data_type: &DataType, cast_options: CastOptions) -> PolarsResult { - self.0.cast_with_options(data_type, cast_options) + fn cast(&self, dtype: &DataType, cast_options: CastOptions) -> PolarsResult { + self.0.cast_with_options(dtype, cast_options) } fn get(&self, index: usize) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/struct__.rs b/crates/polars-core/src/series/implementations/struct_.rs similarity index 86% rename from crates/polars-core/src/series/implementations/struct__.rs rename to crates/polars-core/src/series/implementations/struct_.rs index a6c775a4245d..805f06d86bac 100644 --- a/crates/polars-core/src/series/implementations/struct__.rs +++ b/crates/polars-core/src/series/implementations/struct_.rs @@ -32,12 +32,6 @@ impl PrivateSeries for SeriesWrap { fn _set_flags(&mut self, _flags: MetadataFlags) {} - fn explode_by_offsets(&self, offsets: &[i64]) -> Series { - self._apply_fields(|s| s.explode_by_offsets(offsets)) - .unwrap() - .into_series() - } - // TODO! remove this. Very slow. Asof join should use row-encoding. 
unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { let other = other.struct_().unwrap(); @@ -56,15 +50,9 @@ impl PrivateSeries for SeriesWrap { #[cfg(feature = "zip_with")] fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { - let other = other.struct_()?; - let fields = self - .0 - .fields_as_series() - .iter() - .zip(other.fields_as_series()) - .map(|(lhs, rhs)| lhs.zip_with_same_type(mask, &rhs)) - .collect::>>()?; - StructChunked::from_series(self.0.name(), &fields).map(|ca| ca.into_series()) + self.0 + .zip_with(mask, other.struct_()?) + .map(|ca| ca.into_series()) } #[cfg(feature = "algorithm_group_by")] @@ -72,7 +60,7 @@ impl PrivateSeries for SeriesWrap { self.0.agg_list(groups) } - fn vec_hash(&self, build_hasher: RandomState, buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash(&self, build_hasher: PlRandomState, buf: &mut Vec) -> PolarsResult<()> { let mut fields = self.0.fields_as_series().into_iter(); if let Some(s) = fields.next() { @@ -86,7 +74,7 @@ impl PrivateSeries for SeriesWrap { } impl SeriesTrait for SeriesWrap { - fn rename(&mut self, name: &str) { + fn rename(&mut self, name: PlSmallStr) { self.0.rename(name) } @@ -94,7 +82,7 @@ impl SeriesTrait for SeriesWrap { self.0.chunk_lengths() } - fn name(&self) -> &str { + fn name(&self) -> &PlSmallStr { self.0.name() } @@ -155,7 +143,7 @@ impl SeriesTrait for SeriesWrap { } fn new_from_index(&self, _index: usize, _length: usize) -> Series { - self.0.new_from_index(_length, _index).into_series() + self.0.new_from_index(_index, _length).into_series() } fn cast(&self, dtype: &DataType, cast_options: CastOptions) -> PolarsResult { @@ -209,12 +197,12 @@ impl SeriesTrait for SeriesWrap { fn arg_unique(&self) -> PolarsResult { // this can called in aggregation, so this fast path can be worth a lot if self.len() == 1 { - return Ok(IdxCa::new_vec(self.name(), vec![0 as IdxSize])); + return Ok(IdxCa::new_vec(self.name().clone(), vec![0 as IdxSize])); } let main_thread = POOL.current_thread_index().is_none(); let groups = self.group_tuples(main_thread, true)?; let first = groups.take_group_firsts(); - Ok(IdxCa::from_vec(self.name(), first)) + Ok(IdxCa::from_vec(self.name().clone(), first)) } fn has_nulls(&self) -> bool { @@ -229,7 +217,7 @@ impl SeriesTrait for SeriesWrap { }; BooleanArray::from_data_default(bitmap, None) }); - BooleanChunked::from_chunk_iter(self.name(), iter) + BooleanChunked::from_chunk_iter(self.name().clone(), iter) } fn is_not_null(&self) -> BooleanChunked { @@ -240,7 +228,7 @@ impl SeriesTrait for SeriesWrap { }; BooleanArray::from_data_default(bitmap, None) }); - BooleanChunked::from_chunk_iter(self.name(), iter) + BooleanChunked::from_chunk_iter(self.name().clone(), iter) } fn reverse(&self) -> Series { @@ -248,10 +236,7 @@ impl SeriesTrait for SeriesWrap { } fn shift(&self, periods: i64) -> Series { - self.0 - ._apply_fields(|s| s.shift(periods)) - .unwrap() - .into_series() + self.0.shift(periods).into_series() } fn clone_inner(&self) -> Arc { diff --git a/crates/polars-core/src/series/implementations/time.rs b/crates/polars-core/src/series/implementations/time.rs index 5f4df072b30c..3808f7d977af 100644 --- a/crates/polars-core/src/series/implementations/time.rs +++ b/crates/polars-core/src/series/implementations/time.rs @@ -39,10 +39,6 @@ impl private::PrivateSeries for SeriesWrap { self.0.set_flags(flags) } - fn explode_by_offsets(&self, offsets: &[i64]) -> Series { - self.0.explode_by_offsets(offsets).into_time().into_series() - } 
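The struct implementation above now derives `is_null`/`is_not_null` from the outer validity bitmap, and together with the `full_null` change further down in this diff, a full-null struct Series reports every row as null. A minimal behavioural sketch, assuming the `dtype-struct` feature is enabled:

```rust
use polars_core::prelude::*;

fn main() {
    // A struct dtype with a single Int32 field; Field::new now takes an owned name.
    let dtype = DataType::Struct(vec![Field::new("a".into(), DataType::Int32)]);
    // full_null now zeroes the struct's outer validity, so every row is null.
    let s = Series::full_null("s".into(), 3, &dtype);
    // is_null is built from that outer validity bitmap.
    assert!(s.is_null().all());
}
```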
- #[cfg(feature = "zip_with")] fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { let other = other.to_physical_repr().into_owned(); @@ -51,12 +47,16 @@ impl private::PrivateSeries for SeriesWrap { .map(|ca| ca.into_time().into_series()) } - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash(&self, random_state: PlRandomState, buf: &mut Vec) -> PolarsResult<()> { self.0.vec_hash(random_state, buf)?; Ok(()) } - fn vec_hash_combine(&self, build_hasher: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { + fn vec_hash_combine( + &self, + build_hasher: PlRandomState, + hashes: &mut [u64], + ) -> PolarsResult<()> { self.0.vec_hash_combine(build_hasher, hashes)?; Ok(()) } @@ -115,14 +115,14 @@ impl private::PrivateSeries for SeriesWrap { } impl SeriesTrait for SeriesWrap { - fn rename(&mut self, name: &str) { + fn rename(&mut self, name: PlSmallStr) { self.0.rename(name); } fn chunk_lengths(&self) -> ChunkLenIter { self.0.chunk_lengths() } - fn name(&self) -> &str { + fn name(&self) -> &PlSmallStr { self.0.name() } @@ -145,6 +145,10 @@ impl SeriesTrait for SeriesWrap { (a.into_series(), b.into_series()) } + fn _sum_as_f64(&self) -> f64 { + self.0._sum_as_f64() + } + fn mean(&self) -> Option { self.0.mean() } @@ -160,7 +164,7 @@ impl SeriesTrait for SeriesWrap { // ref Cow // ref SeriesTrait // ref ChunkedArray - self.0.append(other.as_ref().as_ref().as_ref()); + self.0.append(other.as_ref().as_ref().as_ref())?; Ok(()) } @@ -171,7 +175,7 @@ impl SeriesTrait for SeriesWrap { // ref SeriesTrait // ref ChunkedArray let other = other.to_physical_repr(); - self.0.extend(other.as_ref().as_ref().as_ref()); + self.0.extend(other.as_ref().as_ref().as_ref())?; Ok(()) } @@ -210,8 +214,8 @@ impl SeriesTrait for SeriesWrap { .into_series() } - fn cast(&self, data_type: &DataType, cast_options: CastOptions) -> PolarsResult { - match data_type { + fn cast(&self, dtype: &DataType, cast_options: CastOptions) -> PolarsResult { + match dtype { DataType::String => Ok(self .0 .clone() @@ -220,7 +224,7 @@ impl SeriesTrait for SeriesWrap { .unwrap() .to_string("%T") .into_series()), - _ => self.0.cast_with_options(data_type, cast_options), + _ => self.0.cast_with_options(dtype, cast_options), } } diff --git a/crates/polars-core/src/series/into.rs b/crates/polars-core/src/series/into.rs index d1f722a9bd7e..1213c3346525 100644 --- a/crates/polars-core/src/series/into.rs +++ b/crates/polars-core/src/series/into.rs @@ -34,7 +34,7 @@ impl Series { let dtype = &field.dtype; let s = unsafe { Series::from_chunks_and_dtype_unchecked( - "", + PlSmallStr::EMPTY, vec![values.clone()], &dtype.to_physical(), ) @@ -59,7 +59,7 @@ impl Series { // We pass physical arrays and cast to logical before we convert to arrow. let s = unsafe { Series::from_chunks_and_dtype_unchecked( - "", + PlSmallStr::EMPTY, vec![arr.values().clone()], &inner.to_physical(), ) @@ -70,9 +70,9 @@ impl Series { s.to_arrow(0, compat_level) }; - let data_type = ListArray::::default_datatype(inner.to_arrow(compat_level)); + let dtype = ListArray::::default_datatype(inner.to_arrow(compat_level)); let arr = ListArray::::new( - data_type, + dtype, arr.offsets().clone(), new_values, arr.validity().cloned(), @@ -84,7 +84,7 @@ impl Series { let ca = self.categorical().unwrap(); let arr = ca.physical().chunks()[chunk_idx].clone(); // SAFETY: categoricals are always u32's. 
- let cats = unsafe { UInt32Chunked::from_chunks("", vec![arr]) }; + let cats = unsafe { UInt32Chunked::from_chunks(PlSmallStr::EMPTY, vec![arr]) }; // SAFETY: we only take a single chunk and change nothing about the index/rev_map mapping. let new = unsafe { diff --git a/crates/polars-core/src/series/iterator.rs b/crates/polars-core/src/series/iterator.rs index b8d6385cdbe8..d4dc5df63ccb 100644 --- a/crates/polars-core/src/series/iterator.rs +++ b/crates/polars-core/src/series/iterator.rs @@ -43,6 +43,13 @@ from_iterator!(f32, Float32Type); from_iterator!(f64, Float64Type); from_iterator!(bool, BooleanType); +impl<'a> FromIterator> for Series { + fn from_iter>>(iter: I) -> Self { + let ca: StringChunked = iter.into_iter().collect(); + ca.into_series() + } +} + impl<'a> FromIterator<&'a str> for Series { fn from_iter>(iter: I) -> Self { let ca: StringChunked = iter.into_iter().collect(); @@ -50,6 +57,13 @@ impl<'a> FromIterator<&'a str> for Series { } } +impl FromIterator> for Series { + fn from_iter>>(iter: T) -> Self { + let ca: StringChunked = iter.into_iter().collect(); + ca.into_series() + } +} + impl FromIterator for Series { fn from_iter>(iter: I) -> Self { let ca: StringChunked = iter.into_iter().collect(); @@ -186,11 +200,27 @@ mod test { #[test] fn test_iter() { - let a = Series::new("age", [23, 71, 9].as_ref()); + let a = Series::new("age".into(), [23, 71, 9].as_ref()); let _b = a .i32() .unwrap() .into_iter() .map(|opt_v| opt_v.map(|v| v * 2)); } + + #[test] + fn test_iter_str() { + let data = [Some("John"), Some("Doe"), None]; + let a: Series = data.into_iter().collect(); + let b = Series::new("".into(), data); + assert_eq!(a, b); + } + + #[test] + fn test_iter_string() { + let data = [Some("John".to_string()), Some("Doe".to_string()), None]; + let a: Series = data.clone().into_iter().collect(); + let b = Series::new("".into(), data); + assert_eq!(a, b); + } } diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index f3e15779100a..a629a8fd1c5c 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -17,7 +17,6 @@ use std::borrow::Cow; use std::hash::{Hash, Hasher}; use std::ops::Deref; -use ahash::RandomState; use arrow::compute::aggregate::estimated_bytes_size; use arrow::offset::Offsets; pub use from::*; @@ -53,7 +52,7 @@ use crate::POOL; /// You can do standard arithmetic on series. 
/// ``` /// # use polars_core::prelude::*; -/// let s = Series::new("a", [1 , 2, 3]); +/// let s = Series::new("a".into(), [1 , 2, 3]); /// let out_add = &s + &s; /// let out_sub = &s - &s; /// let out_div = &s / &s; @@ -80,7 +79,7 @@ use crate::POOL; /// /// ``` /// # use polars_core::prelude::*; -/// let s = Series::new("dollars", &[1, 2, 3]); +/// let s = Series::new("dollars".into(), &[1, 2, 3]); /// let mask = s.equal(1).unwrap(); /// let valid = [true, false, false].iter(); /// assert!(mask @@ -102,7 +101,7 @@ use crate::POOL; /// ``` /// use polars_core::prelude::*; /// let pi = 3.14; -/// let s = Series::new("angle", [2f32 * pi, pi, 1.5 * pi].as_ref()); +/// let s = Series::new("angle".into(), [2f32 * pi, pi, 1.5 * pi].as_ref()); /// let s_cos: Series = s.f32() /// .expect("series was not an f32 dtype") /// .into_iter() @@ -117,10 +116,10 @@ use crate::POOL; /// ``` /// # use polars_core::prelude::*; /// // Series can be created from Vec's, slices and arrays -/// Series::new("boolean series", &[true, false, true]); -/// Series::new("int series", &[1, 2, 3]); +/// Series::new("boolean series".into(), &[true, false, true]); +/// Series::new("int series".into(), &[1, 2, 3]); /// // And can be nullable -/// Series::new("got nulls", &[Some(1), None, Some(2)]); +/// Series::new("got nulls".into(), &[Some(1), None, Some(2)]); /// /// // Series can also be collected from iterators /// let from_iter: Series = (0..10) @@ -142,7 +141,7 @@ impl Eq for Wrap {} impl Hash for Wrap { fn hash(&self, state: &mut H) { - let rs = RandomState::with_seeds(0, 0, 0, 0); + let rs = PlRandomState::with_seeds(0, 0, 0, 0); let mut h = vec![]; if self.0.vec_hash(rs, &mut h).is_ok() { let h = h.into_iter().fold(0, |a: u64, b| a.wrapping_add(b)); @@ -157,7 +156,7 @@ impl Hash for Wrap { impl Series { /// Create a new empty Series. - pub fn new_empty(name: &str, dtype: &DataType) -> Series { + pub fn new_empty(name: PlSmallStr, dtype: &DataType) -> Series { Series::full_null(name, 0, dtype) } @@ -168,9 +167,9 @@ impl Series { match self.dtype() { #[cfg(feature = "object")] DataType::Object(_, _) => self - .take(&ChunkedArray::::new_vec("", vec![])) + .take(&ChunkedArray::::new_vec(PlSmallStr::EMPTY, vec![])) .unwrap(), - dt => Series::new_empty(self.name(), dt), + dt => Series::new_empty(self.name().clone(), dt), } } } @@ -260,18 +259,18 @@ impl Series { } /// Rename series. - pub fn rename(&mut self, name: &str) -> &mut Series { + pub fn rename(&mut self, name: PlSmallStr) -> &mut Series { self._get_inner_mut().rename(name); self } /// Return this Series with a new name. - pub fn with_name(mut self, name: &str) -> Series { + pub fn with_name(mut self, name: PlSmallStr) -> Series { self.rename(name); self } - /// Try to set the [`Metadata`] for the underlying [`ChunkedArray`] + /// to set the [`Metadata`] for the underlying [`ChunkedArray`] /// /// This does not guarantee that the [`Metadata`] is always set. It returns whether it was /// successful. 
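The hunks above migrate `Series` construction, `rename` and `with_name` from `&str` to the owned `PlSmallStr`. A minimal usage sketch of the new call sites, assuming `PlSmallStr` keeps a `From<&str>` conversion and an `as_str()` accessor:

```rust
use polars_core::prelude::*;

fn main() {
    // String literals are converted to PlSmallStr via `.into()`.
    let mut s = Series::new("a".into(), &[1i32, 2, 3]);
    s.rename("b".into());
    // `name()` now returns `&PlSmallStr` instead of `&str`.
    assert_eq!(s.name().as_str(), "b");
    // `with_name` consumes the Series and returns it under the new name.
    let s = s.with_name("c".into());
    assert_eq!(s.name().as_str(), "c");
}
```

Existing call sites generally only need `.into()` on the name argument; the shape of the API is otherwise unchanged.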
@@ -289,16 +288,16 @@ impl Series { true } - pub fn from_arrow_chunks(name: &str, arrays: Vec) -> PolarsResult { + pub fn from_arrow_chunks(name: PlSmallStr, arrays: Vec) -> PolarsResult { Self::try_from((name, arrays)) } - pub fn from_arrow(name: &str, array: ArrayRef) -> PolarsResult { + pub fn from_arrow(name: PlSmallStr, array: ArrayRef) -> PolarsResult { Self::try_from((name, array)) } #[cfg(feature = "arrow_rs")] - pub fn from_arrow_rs(name: &str, array: &dyn arrow_array::Array) -> PolarsResult { + pub fn from_arrow_rs(name: PlSmallStr, array: &dyn arrow_array::Array) -> PolarsResult { Self::from_arrow(name, array.into()) } @@ -347,9 +346,9 @@ impl Series { /// ```rust /// # use polars_core::prelude::*; /// # fn main() -> PolarsResult<()> { - /// let s = Series::new("foo", [2, 1, 3]); + /// let s = Series::new("foo".into(), [2, 1, 3]); /// let sorted = s.sort(SortOptions::default())?; - /// assert_eq!(sorted, Series::new("foo", [1, 2, 3])); + /// assert_eq!(sorted, Series::new("foo".into(), [1, 2, 3])); /// # Ok(()) /// } /// ``` @@ -438,7 +437,7 @@ impl Series { // Always allow casting all nulls to other all nulls. let len = self.len(); if self.null_count() == len { - return Ok(Series::full_null(self.name(), len, dtype)); + return Ok(Series::full_null(self.name().clone(), len, dtype)); } let new_options = match options { @@ -541,7 +540,9 @@ impl Series { match self.dtype() { DataType::Float32 => Ok(self.f32().unwrap().is_nan()), DataType::Float64 => Ok(self.f64().unwrap().is_nan()), - dt if dt.is_numeric() => Ok(BooleanChunked::full(self.name(), false, self.len())), + dt if dt.is_numeric() => { + Ok(BooleanChunked::full(self.name().clone(), false, self.len())) + }, _ => polars_bail!(opq = is_nan, self.dtype()), } } @@ -551,7 +552,9 @@ impl Series { match self.dtype() { DataType::Float32 => Ok(self.f32().unwrap().is_not_nan()), DataType::Float64 => Ok(self.f64().unwrap().is_not_nan()), - dt if dt.is_numeric() => Ok(BooleanChunked::full(self.name(), true, self.len())), + dt if dt.is_numeric() => { + Ok(BooleanChunked::full(self.name().clone(), true, self.len())) + }, _ => polars_bail!(opq = is_not_nan, self.dtype()), } } @@ -561,7 +564,9 @@ impl Series { match self.dtype() { DataType::Float32 => Ok(self.f32().unwrap().is_finite()), DataType::Float64 => Ok(self.f64().unwrap().is_finite()), - dt if dt.is_numeric() => Ok(BooleanChunked::full(self.name(), true, self.len())), + dt if dt.is_numeric() => { + Ok(BooleanChunked::full(self.name().clone(), true, self.len())) + }, _ => polars_bail!(opq = is_finite, self.dtype()), } } @@ -571,7 +576,9 @@ impl Series { match self.dtype() { DataType::Float32 => Ok(self.f32().unwrap().is_infinite()), DataType::Float64 => Ok(self.f64().unwrap().is_infinite()), - dt if dt.is_numeric() => Ok(BooleanChunked::full(self.name(), false, self.len())), + dt if dt.is_numeric() => { + Ok(BooleanChunked::full(self.name().clone(), false, self.len())) + }, _ => polars_bail!(opq = is_infinite, self.dtype()), } } @@ -621,7 +628,7 @@ impl Series { .iter() .map(|s| s.to_physical_repr().into_owned()) .collect(); - let mut ca = StructChunked::from_series(self.name(), &fields).unwrap(); + let mut ca = StructChunked::from_series(self.name().clone(), &fields).unwrap(); if arr.null_count() > 0 { ca.zip_outer_validity(arr); @@ -644,7 +651,7 @@ impl Series { pub fn gather_every(&self, n: usize, offset: usize) -> Series { let idx = ((offset as IdxSize)..self.len() as IdxSize) .step_by(n) - .collect_ca(""); + .collect_ca(PlSmallStr::EMPTY); // SAFETY: we stay in-bounds. 
unsafe { self.take_unchecked(&idx) } } @@ -889,11 +896,11 @@ impl Series { let offsets = (0i64..(s.len() as i64 + 1)).collect::>(); let offsets = unsafe { Offsets::new_unchecked(offsets) }; - let data_type = LargeListArray::default_datatype( + let dtype = LargeListArray::default_datatype( s.dtype().to_physical().to_arrow(CompatLevel::newest()), ); - let new_arr = LargeListArray::new(data_type, offsets.into(), values, None); - let mut out = ListChunked::with_chunk(s.name(), new_arr); + let new_arr = LargeListArray::new(dtype, offsets.into(), values, None); + let mut out = ListChunked::with_chunk(s.name().clone(), new_arr); out.set_inner_dtype(s.dtype().clone()); out } @@ -970,7 +977,7 @@ mod test { #[test] fn cast() { - let ar = UInt32Chunked::new("a", &[1, 2]); + let ar = UInt32Chunked::new("a".into(), &[1, 2]); let s = ar.into_series(); let s2 = s.cast(&DataType::Int64).unwrap(); @@ -981,9 +988,9 @@ mod test { #[test] fn new_series() { - let _ = Series::new("boolean series", &vec![true, false, true]); - let _ = Series::new("int series", &[1, 2, 3]); - let ca = Int32Chunked::new("a", &[1, 2, 3]); + let _ = Series::new("boolean series".into(), &vec![true, false, true]); + let _ = Series::new("int series".into(), &[1, 2, 3]); + let ca = Int32Chunked::new("a".into(), &[1, 2, 3]); let _ = ca.into_series(); } @@ -992,7 +999,7 @@ mod test { fn new_series_from_empty_structs() { let dtype = DataType::Struct(vec![]); let empties = vec![AnyValue::StructOwned(Box::new((vec![], vec![]))); 3]; - let s = Series::from_any_values_and_dtype("", &empties, &dtype, false).unwrap(); + let s = Series::from_any_values_and_dtype("".into(), &empties, &dtype, false).unwrap(); assert_eq!(s.len(), 3); } #[test] @@ -1000,28 +1007,28 @@ mod test { let array = UInt32Array::from_slice([1, 2, 3, 4, 5]); let array_ref: ArrayRef = Box::new(array); - let _ = Series::try_from(("foo", array_ref)).unwrap(); + let _ = Series::try_new("foo".into(), array_ref).unwrap(); } #[test] fn series_append() { - let mut s1 = Series::new("a", &[1, 2]); - let s2 = Series::new("b", &[3]); + let mut s1 = Series::new("a".into(), &[1, 2]); + let s2 = Series::new("b".into(), &[3]); s1.append(&s2).unwrap(); assert_eq!(s1.len(), 3); // add wrong type - let s2 = Series::new("b", &[3.0]); + let s2 = Series::new("b".into(), &[3.0]); assert!(s1.append(&s2).is_err()) } #[test] #[cfg(feature = "dtype-decimal")] fn series_append_decimal() { - let s1 = Series::new("a", &[1.1, 2.3]) + let s1 = Series::new("a".into(), &[1.1, 2.3]) .cast(&DataType::Decimal(None, Some(2))) .unwrap(); - let s2 = Series::new("b", &[3]) + let s2 = Series::new("b".into(), &[3]) .cast(&DataType::Decimal(None, Some(0))) .unwrap(); @@ -1041,7 +1048,7 @@ mod test { #[test] fn series_slice_works() { - let series = Series::new("a", &[1i64, 2, 3, 4, 5]); + let series = Series::new("a".into(), &[1i64, 2, 3, 4, 5]); let slice_1 = series.slice(-3, 3); let slice_2 = series.slice(-5, 5); @@ -1054,7 +1061,7 @@ mod test { #[test] fn out_of_range_slice_does_not_panic() { - let series = Series::new("a", &[1i64, 2, 3, 4, 5]); + let series = Series::new("a".into(), &[1i64, 2, 3, 4, 5]); let _ = series.slice(-3, 4); let _ = series.slice(-6, 2); diff --git a/crates/polars-core/src/series/ops/downcast.rs b/crates/polars-core/src/series/ops/downcast.rs index 6441dfe03df4..ce57e42c610c 100644 --- a/crates/polars-core/src/series/ops/downcast.rs +++ b/crates/polars-core/src/series/ops/downcast.rs @@ -36,7 +36,7 @@ impl Series { /// Unpack to [`ChunkedArray`] /// ``` /// # use polars_core::prelude::*; - /// 
let s = Series::new("foo", [1i32 ,2, 3]); + /// let s = Series::new("foo".into(), [1i32 ,2, 3]); /// let s_squared: Series = s.i32() /// .unwrap() /// .into_iter() diff --git a/crates/polars-core/src/series/ops/extend.rs b/crates/polars-core/src/series/ops/extend.rs index 08a196335f4c..8bb72d515d59 100644 --- a/crates/polars-core/src/series/ops/extend.rs +++ b/crates/polars-core/src/series/ops/extend.rs @@ -4,7 +4,7 @@ impl Series { /// Extend with a constant value. pub fn extend_constant(&self, value: AnyValue, n: usize) -> PolarsResult { // TODO: Use `from_any_values_and_dtype` here instead of casting afterwards - let s = Series::from_any_values("", &[value], true).unwrap(); + let s = Series::from_any_values(PlSmallStr::EMPTY, &[value], true).unwrap(); let s = s.cast(self.dtype())?; let to_append = s.new_from_index(0, n); diff --git a/crates/polars-core/src/series/ops/null.rs b/crates/polars-core/src/series/ops/null.rs index d13ce699cbad..ee33c309687e 100644 --- a/crates/polars-core/src/series/ops/null.rs +++ b/crates/polars-core/src/series/ops/null.rs @@ -1,9 +1,11 @@ +use arrow::bitmap::Bitmap; + #[cfg(feature = "object")] use crate::chunked_array::object::registry::get_object_builder; use crate::prelude::*; impl Series { - pub fn full_null(name: &str, size: usize, dtype: &DataType) -> Self { + pub fn full_null(name: PlSmallStr, size: usize, dtype: &DataType) -> Self { // match the logical types and create them match dtype { DataType::List(inner_dtype) => { @@ -51,11 +53,16 @@ impl Series { DataType::Struct(fields) => { let fields = fields .iter() - .map(|fld| Series::full_null(fld.name(), size, fld.data_type())) + .map(|fld| Series::full_null(fld.name().clone(), size, fld.dtype())) .collect::>(); - StructChunked::from_series(name, &fields) - .unwrap() - .into_series() + let ca = StructChunked::from_series(name, &fields).unwrap(); + + if !fields.is_empty() { + ca.with_outer_validity(Some(Bitmap::new_zeroed(size))) + .into_series() + } else { + ca.into_series() + } }, DataType::Null => Series::new_null(name, size), DataType::Unknown(kind) => { diff --git a/crates/polars-core/src/series/ops/reshape.rs b/crates/polars-core/src/series/ops/reshape.rs index 550a12f54829..85c8e283e166 100644 --- a/crates/polars-core/src/series/ops/reshape.rs +++ b/crates/polars-core/src/series/ops/reshape.rs @@ -15,7 +15,7 @@ use crate::chunked_array::builder::get_list_builder; use crate::datatypes::{DataType, ListChunked}; use crate::prelude::{IntoSeries, Series, *}; -fn reshape_fast_path(name: &str, s: &Series) -> Series { +fn reshape_fast_path(name: PlSmallStr, s: &Series) -> Series { let mut ca = match s.dtype() { #[cfg(feature = "dtype-struct")] DataType::Struct(_) => { @@ -44,7 +44,7 @@ impl Series { .map(|arr| arr.values().clone()) .collect::>(); // Safety: guarded by the type system - unsafe { Series::from_chunks_and_dtype_unchecked(s.name(), chunks, dtype) } + unsafe { Series::from_chunks_and_dtype_unchecked(s.name().clone(), chunks, dtype) } .get_leaf_array() }, DataType::List(dtype) => { @@ -54,7 +54,7 @@ impl Series { .map(|arr| arr.values().clone()) .collect::>(); // Safety: guarded by the type system - unsafe { Series::from_chunks_and_dtype_unchecked(s.name(), chunks, dtype) } + unsafe { Series::from_chunks_and_dtype_unchecked(s.name().clone(), chunks, dtype) } .get_leaf_array() }, _ => s.clone(), @@ -71,19 +71,19 @@ impl Series { let offsets = vec![0i64, values.len() as i64]; let inner_type = s.dtype(); - let data_type = ListArray::::default_datatype(values.data_type().clone()); + let dtype = 
ListArray::::default_datatype(values.dtype().clone()); // SAFETY: offsets are correct. let arr = unsafe { ListArray::new( - data_type, + dtype, Offsets::new_unchecked(offsets).into(), values.clone(), None, ) }; - let mut ca = ListChunked::with_chunk(s.name(), arr); + let mut ca = ListChunked::with_chunk(s.name().clone(), arr); unsafe { ca.to_logical(inner_type.clone()) }; ca.set_fast_explode(); Ok(ca) @@ -165,7 +165,7 @@ impl Series { } Ok(unsafe { Series::from_chunks_and_dtype_unchecked( - leaf_array.name(), + leaf_array.name().clone(), vec![prev_array], &prev_dtype, ) @@ -203,7 +203,7 @@ impl Series { if s_ref.len() == 0_usize { if (rows == -1 || rows == 0) && (cols == -1 || cols == 0 || cols == 1) { - let s = reshape_fast_path(s.name(), s_ref); + let s = reshape_fast_path(s.name().clone(), s_ref); return Ok(s); } else { polars_bail!(InvalidOperation: "cannot reshape len 0 into shape {:?}", dimensions,) @@ -222,7 +222,7 @@ impl Series { // Fast path, we can create a unit list so we only allocate offsets. if rows as usize == s_ref.len() && cols == 1 { - let s = reshape_fast_path(s.name(), s_ref); + let s = reshape_fast_path(s.name().clone(), s_ref); return Ok(s); } @@ -232,7 +232,7 @@ impl Series { ); let mut builder = - get_list_builder(s_ref.dtype(), s_ref.len(), rows as usize, s.name())?; + get_list_builder(s_ref.dtype(), s_ref.len(), rows as usize, s.name().clone())?; let mut offset = 0i64; for _ in 0..rows { @@ -256,9 +256,9 @@ mod test { #[test] fn test_to_list() -> PolarsResult<()> { - let s = Series::new("a", &[1, 2, 3]); + let s = Series::new("a".into(), &[1, 2, 3]); - let mut builder = get_list_builder(s.dtype(), s.len(), 1, s.name())?; + let mut builder = get_list_builder(s.dtype(), s.len(), 1, s.name().clone())?; builder.append_series(&s).unwrap(); let expected = builder.finish(); @@ -270,7 +270,7 @@ mod test { #[test] fn test_reshape() -> PolarsResult<()> { - let s = Series::new("a", &[1, 2, 3, 4]); + let s = Series::new("a".into(), &[1, 2, 3, 4]); for (dims, list_len) in [ (&[-1, 1], 4), diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs index 03a9384d4f50..b5b60c5eff33 100644 --- a/crates/polars-core/src/series/series_trait.rs +++ b/crates/polars-core/src/series/series_trait.rs @@ -46,8 +46,6 @@ pub enum BitRepr { } pub(crate) mod private { - use ahash::RandomState; - use super::*; use crate::chunked_array::metadata::MetadataFlags; use crate::chunked_array::ops::compare_inner::{TotalEqInner, TotalOrdInner}; @@ -63,7 +61,7 @@ pub(crate) mod private { #[cfg(feature = "object")] fn get_list_builder( &self, - _name: &str, + _name: PlSmallStr, _values_capacity: usize, _list_capacity: usize, ) -> Box { @@ -81,10 +79,6 @@ pub(crate) mod private { fn _set_flags(&mut self, flags: MetadataFlags); - fn explode_by_offsets(&self, _offsets: &[i64]) -> Series { - invalid_operation_panic!(explode_by_offsets, self) - } - unsafe fn equal_element( &self, _idx_self: usize, @@ -101,41 +95,41 @@ pub(crate) mod private { fn into_total_ord_inner<'a>(&'a self) -> Box { invalid_operation_panic!(into_total_ord_inner, self) } - fn vec_hash(&self, _build_hasher: RandomState, _buf: &mut Vec) -> PolarsResult<()> { + fn vec_hash(&self, _build_hasher: PlRandomState, _buf: &mut Vec) -> PolarsResult<()> { polars_bail!(opq = vec_hash, self._dtype()); } fn vec_hash_combine( &self, - _build_hasher: RandomState, + _build_hasher: PlRandomState, _hashes: &mut [u64], ) -> PolarsResult<()> { polars_bail!(opq = vec_hash_combine, self._dtype()); } #[cfg(feature = 
"algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { - Series::full_null(self._field().name(), groups.len(), self._dtype()) + Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) } #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { - Series::full_null(self._field().name(), groups.len(), self._dtype()) + Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) } /// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is /// first cast to `Int64` to prevent overflow issues. #[cfg(feature = "algorithm_group_by")] unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { - Series::full_null(self._field().name(), groups.len(), self._dtype()) + Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) } #[cfg(feature = "algorithm_group_by")] unsafe fn agg_std(&self, groups: &GroupsProxy, _ddof: u8) -> Series { - Series::full_null(self._field().name(), groups.len(), self._dtype()) + Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) } #[cfg(feature = "algorithm_group_by")] unsafe fn agg_var(&self, groups: &GroupsProxy, _ddof: u8) -> Series { - Series::full_null(self._field().name(), groups.len(), self._dtype()) + Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) } #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { - Series::full_null(self._field().name(), groups.len(), self._dtype()) + Series::full_null(self._field().name().clone(), groups.len(), self._dtype()) } fn subtract(&self, _rhs: &Series) -> PolarsResult { @@ -181,7 +175,7 @@ pub trait SeriesTrait: Send + Sync + private::PrivateSeries + private::PrivateSeriesNumeric { /// Rename the Series. - fn rename(&mut self, name: &str); + fn rename(&mut self, name: PlSmallStr); fn bitand(&self, _other: &Series) -> PolarsResult { polars_bail!(opq = bitand, self._dtype()); @@ -203,7 +197,7 @@ pub trait SeriesTrait: fn chunk_lengths(&self) -> ChunkLenIter; /// Name of series. - fn name(&self) -> &str; + fn name(&self) -> &PlSmallStr; /// Get field (used in schema) fn field(&self) -> Cow { @@ -298,6 +292,11 @@ pub trait SeriesTrait: } } + /// Returns the sum of the array as an f64. + fn _sum_as_f64(&self) -> f64 { + invalid_operation_panic!(_sum_as_f64, self) + } + /// Returns the mean value in the array /// Returns an option because the array is nullable. fn mean(&self) -> Option { @@ -328,13 +327,13 @@ pub trait SeriesTrait: /// /// ```rust /// use polars_core::prelude::*; - /// let s = Series::new("a", [0i32, 1, 8]); + /// let s = Series::new("a".into(), [0i32, 1, 8]); /// let s2 = s.new_from_index(2, 4); /// assert_eq!(Vec::from(s2.i32().unwrap()), &[Some(8), Some(8), Some(8), Some(8)]) /// ``` fn new_from_index(&self, _index: usize, _length: usize) -> Series; - fn cast(&self, _data_type: &DataType, options: CastOptions) -> PolarsResult; + fn cast(&self, _dtype: &DataType, options: CastOptions) -> PolarsResult; /// Get a single value by index. Don't use this operation for loops as a runtime cast is /// needed for every iteration. 
@@ -410,7 +409,7 @@ pub trait SeriesTrait: /// ```rust /// # use polars_core::prelude::*; /// fn example() -> PolarsResult<()> { - /// let s = Series::new("series", &[1, 2, 3]); + /// let s = Series::new("series".into(), &[1, 2, 3]); /// /// let shifted = s.shift(1); /// assert_eq!(Vec::from(shifted.i32()?), &[None, Some(1), Some(2)]); diff --git a/crates/polars-core/src/testing.rs b/crates/polars-core/src/testing.rs index cb9f6e5389ab..bf056b5f7769 100644 --- a/crates/polars-core/src/testing.rs +++ b/crates/polars-core/src/testing.rs @@ -162,7 +162,9 @@ impl PartialEq for DataFrame { } /// Asserts that two expressions of type [`DataFrame`] are equal according to [`DataFrame::equals`] -/// at runtime. If the expression are not equal, the program will panic with a message that displays +/// at runtime. +/// +/// If the expression are not equal, the program will panic with a message that displays /// both dataframes. #[macro_export] macro_rules! assert_df_eq { @@ -179,26 +181,26 @@ mod test { #[test] fn test_series_equals() { - let a = Series::new("a", &[1_u32, 2, 3]); - let b = Series::new("a", &[1_u32, 2, 3]); + let a = Series::new("a".into(), &[1_u32, 2, 3]); + let b = Series::new("a".into(), &[1_u32, 2, 3]); assert!(a.equals(&b)); - let s = Series::new("foo", &[None, Some(1i64)]); + let s = Series::new("foo".into(), &[None, Some(1i64)]); assert!(s.equals_missing(&s)); } #[test] fn test_series_dtype_not_equal() { - let s_i32 = Series::new("a", &[1_i32, 2_i32]); - let s_i64 = Series::new("a", &[1_i64, 2_i64]); + let s_i32 = Series::new("a".into(), &[1_i32, 2_i32]); + let s_i64 = Series::new("a".into(), &[1_i64, 2_i64]); assert!(s_i32.dtype() != s_i64.dtype()); assert!(s_i32.equals(&s_i64)); } #[test] fn test_df_equal() { - let a = Series::new("a", [1, 2, 3].as_ref()); - let b = Series::new("b", [1, 2, 3].as_ref()); + let a = Series::new("a".into(), [1, 2, 3].as_ref()); + let b = Series::new("b".into(), [1, 2, 3].as_ref()); let df1 = DataFrame::new(vec![a, b]).unwrap(); assert!(df1.equals(&df1)) diff --git a/crates/polars-core/src/tests.rs b/crates/polars-core/src/tests.rs index 12e2701bb836..e8a8111225b7 100644 --- a/crates/polars-core/src/tests.rs +++ b/crates/polars-core/src/tests.rs @@ -4,9 +4,9 @@ use crate::prelude::*; fn test_initial_empty_sort() -> PolarsResult<()> { // https://github.com/pola-rs/polars/issues/1396 let data = vec![1.3; 42]; - let mut series = Series::new("data", Vec::::new()); - let series2 = Series::new("data2", data.clone()); - let series3 = Series::new("data3", data); + let mut series = Series::new("data".into(), Vec::::new()); + let series2 = Series::new("data2".into(), data.clone()); + let series3 = Series::new("data3".into(), data); let df = DataFrame::new(vec![series2, series3])?; for column in df.get_columns().iter() { diff --git a/crates/polars-core/src/utils/flatten.rs b/crates/polars-core/src/utils/flatten.rs index a3cd58c79c92..52b1c69ea6d9 100644 --- a/crates/polars-core/src/utils/flatten.rs +++ b/crates/polars-core/src/utils/flatten.rs @@ -12,7 +12,7 @@ pub fn flatten_df_iter(df: &DataFrame) -> impl Iterator + '_ { // SAFETY: // datatypes are correct let mut out = unsafe { - Series::from_chunks_and_dtype_unchecked(s.name(), vec![arr], s.dtype()) + Series::from_chunks_and_dtype_unchecked(s.name().clone(), vec![arr], s.dtype()) }; out.set_sorted_flag(s.is_sorted_flag()); out @@ -33,7 +33,9 @@ pub fn flatten_series(s: &Series) -> Vec { unsafe { s.chunks() .iter() - .map(|arr| Series::from_chunks_and_dtype_unchecked(name, vec![arr.clone()], dtype)) + 
.map(|arr| { + Series::from_chunks_and_dtype_unchecked(name.clone(), vec![arr.clone()], dtype) + }) .collect() } } diff --git a/crates/polars-core/src/utils/mod.rs b/crates/polars-core/src/utils/mod.rs index 1b07c8206b99..a516626e1abb 100644 --- a/crates/polars-core/src/utils/mod.rs +++ b/crates/polars-core/src/utils/mod.rs @@ -18,7 +18,6 @@ use num_traits::{One, Zero}; use rayon::prelude::*; pub use schema::*; pub use series::*; -use smartstring::alias::String as SmartString; pub use supertype::*; pub use {arrow, rayon}; @@ -40,7 +39,8 @@ pub fn _set_partition_size() -> usize { POOL.current_num_threads() } -/// Just a wrapper structure. Useful for certain impl specializations +/// Just a wrapper structure which is useful for certain impl specializations. +/// /// This is for instance use to implement /// `impl FromIterator for NoNull>` /// as `Option` was already implemented: @@ -160,7 +160,7 @@ impl Container for ChunkedArray { fn iter_chunks(&self) -> impl Iterator { self.downcast_iter() - .map(|arr| Self::with_chunk(self.name(), arr.clone())) + .map(|arr| Self::with_chunk(self.name().clone(), arr.clone())) } fn n_chunks(&self) -> usize { @@ -393,7 +393,7 @@ macro_rules! match_dtype_to_logical_apply_macro { /// Apply a macro on the Downcasted ChunkedArray's #[macro_export] -macro_rules! match_arrow_data_type_apply_macro_ca { +macro_rules! match_arrow_dtype_apply_macro_ca { ($self:expr, $macro:ident, $macro_string:ident, $macro_bool:ident $(, $opt_args:expr)*) => {{ match $self.dtype() { DataType::String => $macro_string!($self.str().unwrap() $(, $opt_args)*), @@ -684,7 +684,7 @@ macro_rules! apply_method_physical_numeric { macro_rules! df { ($($col_name:expr => $slice:expr), + $(,)?) => { $crate::prelude::DataFrame::new(vec![ - $(<$crate::prelude::Series as $crate::prelude::NamedFrom::<_, _>>::new($col_name, $slice),)+ + $(<$crate::prelude::Series as $crate::prelude::NamedFrom::<_, _>>::new($col_name.into(), $slice),)+ ]) } } @@ -744,6 +744,7 @@ where for df in iter { acc_df.vstack_mut(&df)?; } + Ok(acc_df) } @@ -847,6 +848,16 @@ where pub(crate) fn align_chunks_binary_owned_series(left: Series, right: Series) -> (Series, Series) { match (left.chunks().len(), right.chunks().len()) { (1, 1) => (left, right), + // All chunks are equal length + (a, b) + if a == b + && left + .chunk_lengths() + .zip(right.chunk_lengths()) + .all(|(l, r)| l == r) => + { + (left, right) + }, (_, 1) => (left.rechunk(), right), (1, _) => (left, right.rechunk()), (_, _) => (left.rechunk(), right.rechunk()), @@ -863,6 +874,16 @@ where { match (left.chunks.len(), right.chunks.len()) { (1, 1) => (left, right), + // All chunks are equal length + (a, b) + if a == b + && left + .chunk_lengths() + .zip(right.chunk_lengths()) + .all(|(l, r)| l == r) => + { + (left, right) + }, (_, 1) => (left.rechunk(), right), (1, _) => (left, right.rechunk()), (_, _) => (left.rechunk(), right.rechunk()), @@ -974,42 +995,18 @@ where combine_validities_and(left_validity.as_ref(), right_validity.as_ref()) } +/// Convenience for `x.into_iter().map(Into::into).collect()` using an `into_vec()` function. 
pub trait IntoVec { fn into_vec(self) -> Vec; } -pub trait Arg {} -impl Arg for bool {} - -impl IntoVec for bool { - fn into_vec(self) -> Vec { - vec![self] - } -} - -impl IntoVec for Vec { - fn into_vec(self) -> Self { - self - } -} - -impl IntoVec for I -where - I: IntoIterator, - S: AsRef, -{ - fn into_vec(self) -> Vec { - self.into_iter().map(|s| s.as_ref().to_string()).collect() - } -} - -impl IntoVec for I +impl IntoVec for I where I: IntoIterator, - S: AsRef, + S: Into, { - fn into_vec(self) -> Vec { - self.into_iter().map(|s| s.as_ref().into()).collect() + fn into_vec(self) -> Vec { + self.into_iter().map(|s| s.into()).collect() } } @@ -1160,13 +1157,29 @@ pub fn coalesce_nulls_series(a: &Series, b: &Series) -> (Series, Series) { } } +pub fn operation_exceeded_idxsize_msg(operation: &str) -> String { + if core::mem::size_of::() == core::mem::size_of::() { + format!( + "{} exceeded the maximum supported limit of {} rows. Consider installing 'polars-u64-idx'.", + operation, + IdxSize::MAX, + ) + } else { + format!( + "{} exceeded the maximum supported limit of {} rows.", + operation, + IdxSize::MAX, + ) + } +} + #[cfg(test)] mod test { use super::*; #[test] fn test_split() { - let ca: Int32Chunked = (0..10).collect_ca("a"); + let ca: Int32Chunked = (0..10).collect_ca("a".into()); let out = split(&ca, 3); assert_eq!(out[0].len(), 3); @@ -1175,28 +1188,30 @@ mod test { } #[test] - fn test_align_chunks() { - let a = Int32Chunked::new("", &[1, 2, 3, 4]); - let mut b = Int32Chunked::new("", &[1]); - let b2 = Int32Chunked::new("", &[2, 3, 4]); + fn test_align_chunks() -> PolarsResult<()> { + let a = Int32Chunked::new(PlSmallStr::EMPTY, &[1, 2, 3, 4]); + let mut b = Int32Chunked::new(PlSmallStr::EMPTY, &[1]); + let b2 = Int32Chunked::new(PlSmallStr::EMPTY, &[2, 3, 4]); - b.append(&b2); + b.append(&b2)?; let (a, b) = align_chunks_binary(&a, &b); assert_eq!( a.chunk_lengths().collect::>(), b.chunk_lengths().collect::>() ); - let a = Int32Chunked::new("", &[1, 2, 3, 4]); - let mut b = Int32Chunked::new("", &[1]); + let a = Int32Chunked::new(PlSmallStr::EMPTY, &[1, 2, 3, 4]); + let mut b = Int32Chunked::new(PlSmallStr::EMPTY, &[1]); let b1 = b.clone(); - b.append(&b1); - b.append(&b1); - b.append(&b1); + b.append(&b1)?; + b.append(&b1)?; + b.append(&b1)?; let (a, b) = align_chunks_binary(&a, &b); assert_eq!( a.chunk_lengths().collect::>(), b.chunk_lengths().collect::>() ); + + Ok(()) } } diff --git a/crates/polars-core/src/utils/schema.rs b/crates/polars-core/src/utils/schema.rs index c528f3160624..558a0ea8f1b8 100644 --- a/crates/polars-core/src/utils/schema.rs +++ b/crates/polars-core/src/utils/schema.rs @@ -1,3 +1,5 @@ +use polars_utils::format_pl_smallstr; + use crate::prelude::*; /// Convert a collection of [`DataType`] into a schema. 
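// A hedged sketch of a call site for the `operation_exceeded_idxsize_msg` helper
// added above, assuming `polars_core::prelude` exposes the `polars_bail!` macro and
// that the helper remains publicly reachable under `polars_core::utils`; the function
// name `check_row_count` and the "join" label are illustrative, not from the diff.
use polars_core::prelude::*;
use polars_core::utils::operation_exceeded_idxsize_msg;

fn check_row_count(n_rows: u64) -> PolarsResult<()> {
    if n_rows > IdxSize::MAX as u64 {
        // With the default 32-bit IdxSize, the message also suggests the 'polars-u64-idx' build.
        polars_bail!(ComputeError: "{}", operation_exceeded_idxsize_msg("join"));
    }
    Ok(())
}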
@@ -12,6 +14,6 @@ where dtypes .into_iter() .enumerate() - .map(|(i, dtype)| Field::new(format!("column_{i}").as_ref(), dtype)) + .map(|(i, dtype)| Field::new(format_pl_smallstr!("column_{i}"), dtype)) .collect() } diff --git a/crates/polars-core/src/utils/series.rs b/crates/polars-core/src/utils/series.rs index fb9d674100e1..68f138557072 100644 --- a/crates/polars-core/src/utils/series.rs +++ b/crates/polars-core/src/utils/series.rs @@ -9,7 +9,7 @@ pub fn with_unstable_series(dtype: &DataType, f: F) -> T where F: Fn(&mut AmortSeries) -> T, { - let container = Series::full_null("", 0, dtype); + let container = Series::full_null(PlSmallStr::EMPTY, 0, dtype); let mut us = AmortSeries::new(Rc::new(container)); f(&mut us) diff --git a/crates/polars-core/src/utils/supertype.rs b/crates/polars-core/src/utils/supertype.rs index 2068f228411b..027e85886793 100644 --- a/crates/polars-core/src/utils/supertype.rs +++ b/crates/polars-core/src/utils/supertype.rs @@ -1,3 +1,4 @@ +use bitflags::bitflags; use num_traits::Signed; use super::*; @@ -11,9 +12,43 @@ pub fn try_get_supertype(l: &DataType, r: &DataType) -> PolarsResult { ) } +bitflags! { + #[repr(transparent)] + #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] + pub struct SuperTypeFlags: u8 { + /// Implode lists to match nesting types. + const ALLOW_IMPLODE_LIST = 1 << 0; + /// Allow casting of primitive types (numeric, bools) to strings + const ALLOW_PRIMITIVE_TO_STRING = 1 << 1; + } +} + +impl Default for SuperTypeFlags { + fn default() -> Self { + SuperTypeFlags::from_bits_truncate(0) | SuperTypeFlags::ALLOW_PRIMITIVE_TO_STRING + } +} + #[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Default)] pub struct SuperTypeOptions { - pub implode_list: bool, + pub flags: SuperTypeFlags, +} + +impl From for SuperTypeOptions { + fn from(flags: SuperTypeFlags) -> Self { + SuperTypeOptions { flags } + } +} + +impl SuperTypeOptions { + pub fn allow_implode_list(&self) -> bool { + self.flags.contains(SuperTypeFlags::ALLOW_IMPLODE_LIST) + } + + pub fn allow_primitive_to_string(&self) -> bool { + self.flags + .contains(SuperTypeFlags::ALLOW_PRIMITIVE_TO_STRING) + } } pub fn get_supertype(l: &DataType, r: &DataType) -> Option { @@ -209,11 +244,9 @@ pub fn get_supertype_with_options( #[cfg(feature = "dtype-time")] (Time, Float64) => Some(Float64), - // every known type can be casted to a string except binary - (dt, String) if !matches!(dt, DataType::Unknown(UnknownKind::Any)) && dt != &DataType::Binary => Some(String), - - (dt, String) if !matches!(dt, DataType::Unknown(UnknownKind::Any)) => Some(String), - + // Every known type can be cast to a string except binary + (dt, String) if !matches!(dt, Unknown(UnknownKind::Any)) && dt != &Binary && options.allow_primitive_to_string() || !dt.to_physical().is_primitive() => Some(String), + (String, Binary) => Some(Binary), (dt, Null) => Some(dt.clone()), #[cfg(all(feature = "dtype-duration", feature = "dtype-datetime"))] @@ -258,7 +291,7 @@ pub fn get_supertype_with_options( let st = get_supertype(inner_left, inner_right)?; Some(Array(Box::new(st), *width_left)) } - (List(inner), other) | (other, List(inner)) if options.implode_list => { + (List(inner), other) | (other, List(inner)) if options.allow_implode_list() => { let st = get_supertype(inner, other)?; Some(List(Box::new(st))) } @@ -276,8 +309,15 @@ pub fn get_supertype_with_options( }, (dt, Unknown(kind)) => { match kind { + UnknownKind::Float | UnknownKind::Int(_) if dt.is_string() => { + if options.allow_primitive_to_string() { + Some(dt.clone()) + } else { 
+ None + } + }, // numeric vs float|str -> always float|str|decimal - UnknownKind::Float | UnknownKind::Int(_) if dt.is_float() | dt.is_string() | dt.is_decimal() => Some(dt.clone()), + UnknownKind::Float | UnknownKind::Int(_) if dt.is_float() | dt.is_decimal() => Some(dt.clone()), UnknownKind::Float if dt.is_integer() => Some(Unknown(UnknownKind::Float)), // Materialize float to float or decimal UnknownKind::Float if dt.is_float() | dt.is_decimal() => Some(dt.clone()), @@ -329,7 +369,7 @@ pub fn get_supertype_with_options( let mut new_fields = Vec::with_capacity(fields_a.len()); for a in fields_a { let st = get_supertype(&a.dtype, rhs)?; - new_fields.push(Field::new(&a.name, st)) + new_fields.push(Field::new(a.name.clone(), st)) } Some(Struct(new_fields)) } @@ -386,7 +426,7 @@ fn union_struct_fields(fields_a: &[Field], fields_b: &[Field]) -> Option>(); Some(DataType::Struct(new_fields)) } @@ -402,7 +442,7 @@ fn super_type_structs(fields_a: &[Field], fields_b: &[Field]) -> Option = LazyLock::new(|| { + if env::var("POLARS_PANIC_ON_ERR").as_deref() == Ok("1") { + ErrorStrategy::Panic + } else if env::var("POLARS_BACKTRACE_IN_ERR").as_deref() == Ok("1") { + ErrorStrategy::WithBacktrace + } else { + ErrorStrategy::Normal + } +}); + #[derive(Debug)] pub struct ErrString(Cow<'static, str>); +impl ErrString { + pub const fn new_static(s: &'static str) -> Self { + Self(Cow::Borrowed(s)) + } +} + impl From for ErrString where T: Into>, { fn from(msg: T) -> Self { - if env::var("POLARS_PANIC_ON_ERR").as_deref().unwrap_or("") == "1" { - panic!("{}", msg.into()) - } else { - ErrString(msg.into()) + match &*ERROR_STRATEGY { + ErrorStrategy::Panic => panic!("{}", msg.into()), + ErrorStrategy::WithBacktrace => ErrString(Cow::Owned(format!( + "{}\n\nRust backtrace:\n{}", + msg.into(), + std::backtrace::Backtrace::force_capture() + ))), + ErrorStrategy::Normal => ErrString(msg.into()), } } } @@ -178,7 +204,7 @@ impl PolarsError { } } - fn wrap_msg String>(&self, func: F) -> Self { + pub fn wrap_msg String>(&self, func: F) -> Self { use PolarsError::*; match self { ColumnNotFound(msg) => ColumnNotFound(func(msg).into()), diff --git a/crates/polars-expr/Cargo.toml b/crates/polars-expr/Cargo.toml index 72b3a6aaeb66..7a1f974b41ff 100644 --- a/crates/polars-expr/Cargo.toml +++ b/crates/polars-expr/Cargo.toml @@ -13,6 +13,7 @@ ahash = { workspace = true } arrow = { workspace = true } bitflags = { workspace = true } once_cell = { workspace = true } +polars-compute = { workspace = true } polars-core = { workspace = true, features = ["lazy", "zip_with", "random"] } polars-io = { workspace = true, features = ["lazy"] } polars-json = { workspace = true, optional = true } @@ -21,7 +22,6 @@ polars-plan = { workspace = true } polars-time = { workspace = true, optional = true } polars-utils = { workspace = true } rayon = { workspace = true } -smartstring = { workspace = true } [features] nightly = ["polars-core/nightly", "polars-plan/nightly"] diff --git a/crates/polars-expr/src/expressions/aggregation.rs b/crates/polars-expr/src/expressions/aggregation.rs index db45e2b79f6b..5c64ef144cc4 100644 --- a/crates/polars-expr/src/expressions/aggregation.rs +++ b/crates/polars-expr/src/expressions/aggregation.rs @@ -69,16 +69,16 @@ impl PhysicalExpr for AggregationExpr { GroupByMethod::Min => { if MetadataEnv::experimental_enabled() { if let Some(sc) = s.get_metadata().and_then(|v| v.min_value()) { - return Ok(sc.into_series(s.name())); + return Ok(sc.into_series(s.name().clone())); } } match s.is_sorted_flag() { 
IsSorted::Ascending | IsSorted::Descending => { - s.min_reduce().map(|sc| sc.into_series(s.name())) + s.min_reduce().map(|sc| sc.into_series(s.name().clone())) }, IsSorted::Not => parallel_op_series( - |s| s.min_reduce().map(|sc| sc.into_series(s.name())), + |s| s.min_reduce().map(|sc| sc.into_series(s.name().clone())), s, allow_threading, ), @@ -89,7 +89,7 @@ impl PhysicalExpr for AggregationExpr { |s| { Ok(polars_ops::prelude::nan_propagating_aggregate::nan_min_s( &s, - s.name(), + s.name().clone(), )) }, s, @@ -102,16 +102,16 @@ impl PhysicalExpr for AggregationExpr { GroupByMethod::Max => { if MetadataEnv::experimental_enabled() { if let Some(sc) = s.get_metadata().and_then(|v| v.max_value()) { - return Ok(sc.into_series(s.name())); + return Ok(sc.into_series(s.name().clone())); } } match s.is_sorted_flag() { IsSorted::Ascending | IsSorted::Descending => { - s.max_reduce().map(|sc| sc.into_series(s.name())) + s.max_reduce().map(|sc| sc.into_series(s.name().clone())) }, IsSorted::Not => parallel_op_series( - |s| s.max_reduce().map(|sc| sc.into_series(s.name())), + |s| s.max_reduce().map(|sc| sc.into_series(s.name().clone())), s, allow_threading, ), @@ -122,7 +122,7 @@ impl PhysicalExpr for AggregationExpr { |s| { Ok(polars_ops::prelude::nan_propagating_aggregate::nan_max_s( &s, - s.name(), + s.name().clone(), )) }, s, @@ -132,20 +132,20 @@ impl PhysicalExpr for AggregationExpr { GroupByMethod::NanMax => { panic!("activate 'propagate_nans' feature") }, - GroupByMethod::Median => s.median_reduce().map(|sc| sc.into_series(s.name())), - GroupByMethod::Mean => Ok(s.mean_reduce().into_series(s.name())), + GroupByMethod::Median => s.median_reduce().map(|sc| sc.into_series(s.name().clone())), + GroupByMethod::Mean => Ok(s.mean_reduce().into_series(s.name().clone())), GroupByMethod::First => Ok(if s.is_empty() { - Series::full_null(s.name(), 1, s.dtype()) + Series::full_null(s.name().clone(), 1, s.dtype()) } else { s.head(Some(1)) }), GroupByMethod::Last => Ok(if s.is_empty() { - Series::full_null(s.name(), 1, s.dtype()) + Series::full_null(s.name().clone(), 1, s.dtype()) } else { s.tail(Some(1)) }), GroupByMethod::Sum => parallel_op_series( - |s| s.sum_reduce().map(|sc| sc.into_series(s.name())), + |s| s.sum_reduce().map(|sc| sc.into_series(s.name().clone())), s, allow_threading, ), @@ -154,21 +154,26 @@ impl PhysicalExpr for AggregationExpr { if MetadataEnv::experimental_enabled() { if let Some(count) = s.get_metadata().and_then(|v| v.distinct_count()) { let count = count + IdxSize::from(s.null_count() > 0); - return Ok(IdxCa::from_slice(s.name(), &[count]).into_series()); + return Ok(IdxCa::from_slice(s.name().clone(), &[count]).into_series()); } } - s.n_unique() - .map(|count| IdxCa::from_slice(s.name(), &[count as IdxSize]).into_series()) + s.n_unique().map(|count| { + IdxCa::from_slice(s.name().clone(), &[count as IdxSize]).into_series() + }) }, GroupByMethod::Count { include_nulls } => { let count = s.len() - s.null_count() * !include_nulls as usize; - Ok(IdxCa::from_slice(s.name(), &[count as IdxSize]).into_series()) + Ok(IdxCa::from_slice(s.name().clone(), &[count as IdxSize]).into_series()) }, GroupByMethod::Implode => s.implode().map(|ca| ca.into_series()), - GroupByMethod::Std(ddof) => s.std_reduce(ddof).map(|sc| sc.into_series(s.name())), - GroupByMethod::Var(ddof) => s.var_reduce(ddof).map(|sc| sc.into_series(s.name())), + GroupByMethod::Std(ddof) => s + .std_reduce(ddof) + .map(|sc| sc.into_series(s.name().clone())), + GroupByMethod::Var(ddof) => s + .var_reduce(ddof) + .map(|sc| 
sc.into_series(s.name().clone())), GroupByMethod::Quantile(_, _) => unimplemented!(), } } @@ -181,7 +186,7 @@ impl PhysicalExpr for AggregationExpr { ) -> PolarsResult> { let mut ac = self.input.evaluate_on_groups(df, groups, state)?; // don't change names by aggregations as is done in polars-core - let keep_name = ac.series().name().to_string(); + let keep_name = ac.series().name().clone(); polars_ensure!(!matches!(ac.agg_state(), AggState::Literal(_)), ComputeError: "cannot aggregate a literal"); if let AggregatedScalar(_) = ac.agg_state() { @@ -200,27 +205,27 @@ impl PhysicalExpr for AggregationExpr { GroupByMethod::Min => { let (s, groups) = ac.get_final_aggregation(); let agg_s = s.agg_min(&groups); - AggregatedScalar(rename_series(agg_s, &keep_name)) + AggregatedScalar(rename_series(agg_s, keep_name)) }, GroupByMethod::Max => { let (s, groups) = ac.get_final_aggregation(); let agg_s = s.agg_max(&groups); - AggregatedScalar(rename_series(agg_s, &keep_name)) + AggregatedScalar(rename_series(agg_s, keep_name)) }, GroupByMethod::Median => { let (s, groups) = ac.get_final_aggregation(); let agg_s = s.agg_median(&groups); - AggregatedScalar(rename_series(agg_s, &keep_name)) + AggregatedScalar(rename_series(agg_s, keep_name)) }, GroupByMethod::Mean => { let (s, groups) = ac.get_final_aggregation(); let agg_s = s.agg_mean(&groups); - AggregatedScalar(rename_series(agg_s, &keep_name)) + AggregatedScalar(rename_series(agg_s, keep_name)) }, GroupByMethod::Sum => { let (s, groups) = ac.get_final_aggregation(); let agg_s = s.agg_sum(&groups); - AggregatedScalar(rename_series(agg_s, &keep_name)) + AggregatedScalar(rename_series(agg_s, keep_name)) }, GroupByMethod::Count { include_nulls } => { if include_nulls || ac.series().null_count() == 0 { @@ -262,7 +267,7 @@ impl PhysicalExpr for AggregationExpr { counts.into_inner() }, }; - s.rename(&keep_name); + s.rename(keep_name); AggregatedScalar(s.into_series()) }, UpdateGroups::WithGroupsLen => { @@ -270,13 +275,13 @@ impl PhysicalExpr for AggregationExpr { // we can just get the attribute, because we only need the length, // not the correct order let mut ca = ac.groups.group_count(); - ca.rename(&keep_name); + ca.rename(keep_name); AggregatedScalar(ca.into_series()) }, // materialize groups _ => { let mut ca = ac.groups().group_count(); - ca.rename(&keep_name); + ca.rename(keep_name); AggregatedScalar(ca.into_series()) }, } @@ -285,7 +290,7 @@ impl PhysicalExpr for AggregationExpr { match ac.agg_state() { AggState::Literal(s) | AggState::AggregatedScalar(s) => { AggregatedScalar(Series::new( - &keep_name, + keep_name, [(s.len() as IdxSize - s.null_count() as IdxSize)], )) }, @@ -298,13 +303,13 @@ impl PhysicalExpr for AggregationExpr { .map(|s| s.len() as IdxSize - s.null_count() as IdxSize) }) .collect(); - AggregatedScalar(rename_series(out.into_series(), &keep_name)) + AggregatedScalar(rename_series(out.into_series(), keep_name)) }, AggState::NotAggregated(s) => { let s = s.clone(); let groups = ac.groups(); let out: IdxCa = if matches!(s.dtype(), &DataType::Null) { - IdxCa::full(s.name(), 0, groups.len()) + IdxCa::full(s.name().clone(), 0, groups.len()) } else { match groups.as_ref() { GroupsProxy::Idx(idx) => { @@ -322,9 +327,7 @@ impl PhysicalExpr for AggregationExpr { }); count }) - .collect_ca_trusted_with_dtype( - &keep_name, IDX_DTYPE, - ) + .collect_ca_trusted_with_dtype(keep_name, IDX_DTYPE) }, GroupsProxy::Slice { groups, .. 
} => { // Slice and use computed null count @@ -338,9 +341,7 @@ impl PhysicalExpr for AggregationExpr { .null_count() as IdxSize }) - .collect_ca_trusted_with_dtype( - &keep_name, IDX_DTYPE, - ) + .collect_ca_trusted_with_dtype(keep_name, IDX_DTYPE) }, } }; @@ -352,17 +353,17 @@ impl PhysicalExpr for AggregationExpr { GroupByMethod::First => { let (s, groups) = ac.get_final_aggregation(); let agg_s = s.agg_first(&groups); - AggregatedScalar(rename_series(agg_s, &keep_name)) + AggregatedScalar(rename_series(agg_s, keep_name)) }, GroupByMethod::Last => { let (s, groups) = ac.get_final_aggregation(); let agg_s = s.agg_last(&groups); - AggregatedScalar(rename_series(agg_s, &keep_name)) + AggregatedScalar(rename_series(agg_s, keep_name)) }, GroupByMethod::NUnique => { let (s, groups) = ac.get_final_aggregation(); let agg_s = s.agg_n_unique(&groups); - AggregatedScalar(rename_series(agg_s, &keep_name)) + AggregatedScalar(rename_series(agg_s, keep_name)) }, GroupByMethod::Implode => { // if the aggregation is already @@ -380,22 +381,22 @@ impl PhysicalExpr for AggregationExpr { agg.as_list().into_series() }, }; - AggregatedList(rename_series(s, &keep_name)) + AggregatedList(rename_series(s, keep_name)) }, GroupByMethod::Groups => { let mut column: ListChunked = ac.groups().as_list_chunked(); - column.rename(&keep_name); + column.rename(keep_name); AggregatedScalar(column.into_series()) }, GroupByMethod::Std(ddof) => { let (s, groups) = ac.get_final_aggregation(); let agg_s = s.agg_std(&groups, ddof); - AggregatedScalar(rename_series(agg_s, &keep_name)) + AggregatedScalar(rename_series(agg_s, keep_name)) }, GroupByMethod::Var(ddof) => { let (s, groups) = ac.get_final_aggregation(); let agg_s = s.agg_var(&groups, ddof); - AggregatedScalar(rename_series(agg_s, &keep_name)) + AggregatedScalar(rename_series(agg_s, keep_name)) }, GroupByMethod::Quantile(_, _) => { // implemented explicitly in AggQuantile struct @@ -410,7 +411,7 @@ impl PhysicalExpr for AggregationExpr { } else { s.agg_min(&groups) }; - AggregatedScalar(rename_series(agg_s, &keep_name)) + AggregatedScalar(rename_series(agg_s, keep_name)) } #[cfg(not(feature = "propagate_nans"))] { @@ -426,7 +427,7 @@ impl PhysicalExpr for AggregationExpr { } else { s.agg_max(&groups) }; - AggregatedScalar(rename_series(agg_s, &keep_name)) + AggregatedScalar(rename_series(agg_s, keep_name)) } #[cfg(not(feature = "propagate_nans"))] { @@ -455,7 +456,7 @@ impl PhysicalExpr for AggregationExpr { } } -fn rename_series(mut s: Series, name: &str) -> Series { +fn rename_series(mut s: Series, name: PlSmallStr) -> Series { s.rename(name); s } @@ -476,7 +477,7 @@ impl PartitionedAggregation for AggregationExpr { match self.agg_type.groupby { #[cfg(feature = "dtype-struct")] GroupByMethod::Mean => { - let new_name = series.name().to_string(); + let new_name = series.name().clone(); // ensure we don't overflow // the all 8 and 16 bits integers are already upcasted to int16 on `agg_sum` @@ -486,7 +487,7 @@ impl PartitionedAggregation for AggregationExpr { } else { series.agg_sum(groups) }; - agg_s.rename(&new_name); + agg_s.rename(new_name.clone()); if !agg_s.dtype().is_numeric() { Ok(agg_s) @@ -496,48 +497,48 @@ impl PartitionedAggregation for AggregationExpr { _ => agg_s.cast(&DataType::Float64).unwrap(), }; let mut count_s = series.agg_valid_count(groups); - count_s.rename("__POLARS_COUNT"); - Ok(StructChunked::from_series(&new_name, &[agg_s, count_s]) + count_s.rename(PlSmallStr::from_static("__POLARS_COUNT")); + Ok(StructChunked::from_series(new_name, &[agg_s, 
count_s]) .unwrap() .into_series()) } }, GroupByMethod::Implode => { - let new_name = series.name(); + let new_name = series.name().clone(); let mut agg = series.agg_list(groups); agg.rename(new_name); Ok(agg) }, GroupByMethod::First => { let mut agg = series.agg_first(groups); - agg.rename(series.name()); + agg.rename(series.name().clone()); Ok(agg) }, GroupByMethod::Last => { let mut agg = series.agg_last(groups); - agg.rename(series.name()); + agg.rename(series.name().clone()); Ok(agg) }, GroupByMethod::Max => { let mut agg = series.agg_max(groups); - agg.rename(series.name()); + agg.rename(series.name().clone()); Ok(agg) }, GroupByMethod::Min => { let mut agg = series.agg_min(groups); - agg.rename(series.name()); + agg.rename(series.name().clone()); Ok(agg) }, GroupByMethod::Sum => { let mut agg = series.agg_sum(groups); - agg.rename(series.name()); + agg.rename(series.name().clone()); Ok(agg) }, GroupByMethod::Count { include_nulls: true, } => { let mut ca = groups.group_count(); - ca.rename(series.name()); + ca.rename(series.name().clone()); Ok(ca.into_series()) }, _ => { @@ -559,12 +560,12 @@ impl PartitionedAggregation for AggregationExpr { } | GroupByMethod::Sum => { let mut agg = unsafe { partitioned.agg_sum(groups) }; - agg.rename(partitioned.name()); + agg.rename(partitioned.name().clone()); Ok(agg) }, #[cfg(feature = "dtype-struct")] GroupByMethod::Mean => { - let new_name = partitioned.name(); + let new_name = partitioned.name().clone(); match partitioned.dtype() { DataType::Struct(_) => { let ca = partitioned.struct_().unwrap(); @@ -587,7 +588,7 @@ impl PartitionedAggregation for AggregationExpr { // the groups are scattered over multiple groups/sub dataframes. // we now must collect them into a single group let ca = partitioned.list().unwrap(); - let new_name = partitioned.name().to_string(); + let new_name = partitioned.name().clone(); let mut values = Vec::with_capacity(groups.len()); let mut can_fast_explode = true; @@ -631,15 +632,15 @@ impl PartitionedAggregation for AggregationExpr { let vals = values.iter().map(|arr| &**arr).collect::>(); let values = concatenate(&vals).unwrap(); - let data_type = ListArray::::default_datatype(values.data_type().clone()); + let dtype = ListArray::::default_datatype(values.dtype().clone()); // SAFETY: offsets are monotonically increasing. 
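// A minimal sketch of the naming pattern that recurs in these hunks, assuming the
// `polars_core` prelude re-exports `PlSmallStr`; the column names used here are
// illustrative. Names are now owned `PlSmallStr` values, so call sites pass
// `"x".into()`, use `PlSmallStr::from_static` for literals, or clone an existing name.
use polars_core::prelude::*;

fn naming_example() {
    // `Series::new` takes the name by value instead of `&str`.
    let mut s = Series::new("a".into(), &[1i32, 2, 3]);
    // Constant names can use the `'static` constructor.
    s.rename(PlSmallStr::from_static("row_count"));
    // Keeping a name means cloning the cheap `PlSmallStr` handle rather than passing `&str`.
    let kept: PlSmallStr = s.name().clone();
    let renamed = s.with_name(kept);
    assert_eq!(renamed.name().as_str(), "row_count");
}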
let arr = ListArray::::new( - data_type, + dtype, unsafe { Offsets::new_unchecked(offsets).into() }, values, None, ); - let mut ca = ListChunked::with_chunk(&new_name, arr); + let mut ca = ListChunked::with_chunk(new_name, arr); if can_fast_explode { ca.set_fast_explode() } @@ -647,22 +648,22 @@ impl PartitionedAggregation for AggregationExpr { }, GroupByMethod::First => { let mut agg = unsafe { partitioned.agg_first(groups) }; - agg.rename(partitioned.name()); + agg.rename(partitioned.name().clone()); Ok(agg) }, GroupByMethod::Last => { let mut agg = unsafe { partitioned.agg_last(groups) }; - agg.rename(partitioned.name()); + agg.rename(partitioned.name().clone()); Ok(agg) }, GroupByMethod::Max => { let mut agg = unsafe { partitioned.agg_max(groups) }; - agg.rename(partitioned.name()); + agg.rename(partitioned.name().clone()); Ok(agg) }, GroupByMethod::Min => { let mut agg = unsafe { partitioned.agg_min(groups) }; - agg.rename(partitioned.name()); + agg.rename(partitioned.name().clone()); Ok(agg) }, _ => unimplemented!(), @@ -709,7 +710,7 @@ impl PhysicalExpr for AggQuantileExpr { let quantile = self.get_quantile(df, state)?; input .quantile_reduce(quantile, self.interpol) - .map(|sc| sc.into_series(input.name())) + .map(|sc| sc.into_series(input.name().clone())) } #[allow(clippy::ptr_arg)] fn evaluate_on_groups<'a>( @@ -720,7 +721,7 @@ impl PhysicalExpr for AggQuantileExpr { ) -> PolarsResult> { let mut ac = self.input.evaluate_on_groups(df, groups, state)?; // don't change names by aggregations as is done in polars-core - let keep_name = ac.series().name().to_string(); + let keep_name = ac.series().name().clone(); let quantile = self.get_quantile(df, state)?; @@ -731,7 +732,7 @@ impl PhysicalExpr for AggQuantileExpr { .into_owned() .agg_quantile(ac.groups(), quantile, self.interpol) }; - agg.rename(&keep_name); + agg.rename(keep_name); Ok(AggregationContext::from_agg_state( AggregatedScalar(agg), Cow::Borrowed(groups), diff --git a/crates/polars-expr/src/expressions/alias.rs b/crates/polars-expr/src/expressions/alias.rs index fa755fd2b233..a6ea8953288c 100644 --- a/crates/polars-expr/src/expressions/alias.rs +++ b/crates/polars-expr/src/expressions/alias.rs @@ -5,12 +5,12 @@ use crate::expressions::{AggregationContext, PartitionedAggregation, PhysicalExp pub struct AliasExpr { pub(crate) physical_expr: Arc, - pub(crate) name: Arc, + pub(crate) name: PlSmallStr, expr: Expr, } impl AliasExpr { - pub fn new(physical_expr: Arc, name: Arc, expr: Expr) -> Self { + pub fn new(physical_expr: Arc, name: PlSmallStr, expr: Expr) -> Self { Self { physical_expr, name, @@ -19,7 +19,7 @@ impl AliasExpr { } fn finish(&self, input: Series) -> Series { - input.with_name(&self.name) + input.with_name(self.name.clone()) } } @@ -54,11 +54,8 @@ impl PhysicalExpr for AliasExpr { fn to_field(&self, input_schema: &Schema) -> PolarsResult { Ok(Field::new( - &self.name, - self.physical_expr - .to_field(input_schema)? 
- .data_type() - .clone(), + self.name.clone(), + self.physical_expr.to_field(input_schema)?.dtype().clone(), )) } @@ -76,7 +73,7 @@ impl PartitionedAggregation for AliasExpr { ) -> PolarsResult { let agg = self.physical_expr.as_partitioned_aggregator().unwrap(); let s = agg.evaluate_partitioned(df, groups, state)?; - Ok(s.with_name(&self.name)) + Ok(s.with_name(self.name.clone())) } fn finalize( @@ -87,6 +84,6 @@ impl PartitionedAggregation for AliasExpr { ) -> PolarsResult { let agg = self.physical_expr.as_partitioned_aggregator().unwrap(); let s = agg.finalize(partitioned, groups, state)?; - Ok(s.with_name(&self.name)) + Ok(s.with_name(self.name.clone())) } } diff --git a/crates/polars-expr/src/expressions/apply.rs b/crates/polars-expr/src/expressions/apply.rs index 4d13d784540e..2b2bbc1e57d2 100644 --- a/crates/polars-expr/src/expressions/apply.rs +++ b/crates/polars-expr/src/expressions/apply.rs @@ -130,7 +130,7 @@ impl ApplyExpr { Ok(out) } else { let field = self.to_field(self.input_schema.as_ref().unwrap()).unwrap(); - Ok(Series::full_null(field.name(), 1, field.data_type())) + Ok(Series::full_null(field.name().clone(), 1, field.dtype())) } } fn apply_single_group_aware<'a>( @@ -145,17 +145,17 @@ impl ApplyExpr { ComputeError: "cannot aggregate, the column is already aggregated", ); - let name = s.name().to_string(); + let name = s.name().clone(); let agg = ac.aggregated(); // Collection of empty list leads to a null dtype. See: #3687. if agg.len() == 0 { // Create input for the function to determine the output dtype, see #3946. let agg = agg.list().unwrap(); let input_dtype = agg.inner_dtype(); - let input = Series::full_null("", 0, input_dtype); + let input = Series::full_null(PlSmallStr::EMPTY, 0, input_dtype); let output = self.eval_and_flatten(&mut [input])?; - let ca = ListChunked::full(&name, &output, 0); + let ca = ListChunked::full(name, &output, 0); return self.finish_apply_groups(ac, ca); } @@ -163,7 +163,7 @@ impl ApplyExpr { None => Ok(None), Some(mut s) => { if self.pass_name_to_apply { - s.rename(&name); + s.rename(name.clone()); } self.function.call_udf(&mut [s]) }, @@ -181,7 +181,7 @@ impl ApplyExpr { if let Some(dtype) = dtype { // TODO! uncomment this line and remove debug_assertion after a while. // POOL.install(|| { - // iter.collect_ca_with_dtype::>("", DataType::List(Box::new(dtype))) + // iter.collect_ca_with_dtype::>(PlSmallStr::EMPTY, DataType::List(Box::new(dtype))) // })? let out: ListChunked = POOL.install(|| iter.collect::>())?; @@ -199,7 +199,7 @@ impl ApplyExpr { .collect::>()? }; - self.finish_apply_groups(ac, ca.with_name(&name)) + self.finish_apply_groups(ac, ca.with_name(name)) } /// Apply elementwise e.g. ignore the group/list indices. @@ -240,10 +240,7 @@ impl ApplyExpr { // then unpack the lists and finally create iterators from this list chunked arrays. let mut iters = acs .iter_mut() - .map(|ac| { - // SAFETY: unstable series never lives longer than the iterator. - unsafe { ac.iter_groups(self.pass_name_to_apply) } - }) + .map(|ac| ac.iter_groups(self.pass_name_to_apply)) .collect::>(); // Length of the items to iterate over. 
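// A small sketch of the accessor rename visible in this hunk, assuming the
// `polars_core` prelude; the "price" field name is illustrative. `Field::data_type()`
// is now spelled `dtype()`, and constructors such as `Series::full_null` take an
// owned `PlSmallStr` name.
use polars_core::prelude::*;

fn field_example() {
    let field = Field::new("price".into(), DataType::Float64);
    assert_eq!(field.dtype(), &DataType::Float64);
    // Placeholder columns are created from the field's name and dtype.
    let nulls = Series::full_null(field.name().clone(), 1, field.dtype());
    assert_eq!(nulls.len(), 1);
    assert_eq!(nulls.null_count(), 1);
}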
@@ -257,14 +254,14 @@ impl ApplyExpr { ac.with_update_groups(UpdateGroups::No); let agg_state = if self.returns_scalar { - AggState::AggregatedScalar(Series::new_empty(field.name(), &field.dtype)) + AggState::AggregatedScalar(Series::new_empty(field.name().clone(), &field.dtype)) } else { match self.collect_groups { ApplyOptions::ElementWise | ApplyOptions::ApplyList => ac .agg_state() - .map(|_| Series::new_empty(field.name(), &field.dtype)), + .map(|_| Series::new_empty(field.name().clone(), &field.dtype)), ApplyOptions::GroupWise => AggState::AggregatedList(Series::new_empty( - field.name(), + field.name().clone(), &DataType::List(Box::new(field.dtype.clone())), )), } @@ -286,7 +283,7 @@ impl ApplyExpr { self.function.call_udf(&mut container) }) .collect::>()? - .with_name(&field.name); + .with_name(field.name.clone()); drop(iters); @@ -333,8 +330,8 @@ impl PhysicalExpr for ApplyExpr { if self.allow_rename { self.eval_and_flatten(&mut inputs) } else { - let in_name = inputs[0].name().to_string(); - Ok(self.eval_and_flatten(&mut inputs)?.with_name(&in_name)) + let in_name = inputs[0].name().clone(); + Ok(self.eval_and_flatten(&mut inputs)?.with_name(in_name)) } } @@ -580,7 +577,7 @@ impl ApplyExpr { #[cfg(feature = "is_between")] FunctionExpr::Boolean(BooleanFunction::IsBetween { closed }) => { let should_read = || -> Option { - let root: Arc = expr_to_leaf_column_name(&input[0]).ok()?; + let root: PlSmallStr = expr_to_leaf_column_name(&input[0]).ok()?; let Expr::Literal(left) = &input[1] else { return None; }; @@ -595,11 +592,20 @@ impl ApplyExpr { let (left, left_dtype) = (left.to_any_value()?, left.get_datatype()); let (right, right_dtype) = (right.to_any_value()?, right.get_datatype()); - let left = - Series::from_any_values_and_dtype("", &[left], &left_dtype, false).ok()?; - let right = - Series::from_any_values_and_dtype("", &[right], &right_dtype, false) - .ok()?; + let left = Series::from_any_values_and_dtype( + PlSmallStr::EMPTY, + &[left], + &left_dtype, + false, + ) + .ok()?; + let right = Series::from_any_values_and_dtype( + PlSmallStr::EMPTY, + &[right], + &right_dtype, + false, + ) + .ok()?; // don't read the row_group anyways as // the condition will evaluate to false. 
@@ -652,8 +658,8 @@ impl PartitionedAggregation for ApplyExpr { if self.allow_rename { self.eval_and_flatten(&mut [s]) } else { - let in_name = s.name().to_string(); - Ok(self.eval_and_flatten(&mut [s])?.with_name(&in_name)) + let in_name = s.name().clone(); + Ok(self.eval_and_flatten(&mut [s])?.with_name(in_name)) } } diff --git a/crates/polars-expr/src/expressions/binary.rs b/crates/polars-expr/src/expressions/binary.rs index 38d497b0b1bd..d18a1110ff7b 100644 --- a/crates/polars-expr/src/expressions/binary.rs +++ b/crates/polars-expr/src/expressions/binary.rs @@ -133,7 +133,7 @@ impl BinaryExpr { mut ac_l: AggregationContext<'a>, mut ac_r: AggregationContext<'a>, ) -> PolarsResult> { - let name = ac_l.series().name().to_string(); + let name = ac_l.series().name().clone(); ac_l.groups(); ac_r.groups(); polars_ensure!(ac_l.groups.len() == ac_r.groups.len(), ComputeError: "lhs and rhs should have same group length"); @@ -144,7 +144,7 @@ impl BinaryExpr { let res_s = if res_s.len() == 1 { res_s.new_from_index(0, ac_l.groups.len()) } else { - ListChunked::full(&name, &res_s, ac_l.groups.len()).into_series() + ListChunked::full(name, &res_s, ac_l.groups.len()).into_series() }; ac_l.with_series(res_s, true, Some(&self.expr))?; Ok(ac_l) @@ -155,16 +155,14 @@ impl BinaryExpr { mut ac_l: AggregationContext<'a>, mut ac_r: AggregationContext<'a>, ) -> PolarsResult> { - let name = ac_l.series().name().to_string(); - // SAFETY: unstable series never lives longer than the iterator. - let ca = unsafe { - ac_l.iter_groups(false) - .zip(ac_r.iter_groups(false)) - .map(|(l, r)| Some(apply_operator(l?.as_ref(), r?.as_ref(), self.op))) - .map(|opt_res| opt_res.transpose()) - .collect::>()? - .with_name(&name) - }; + let name = ac_l.series().name().clone(); + let ca = ac_l + .iter_groups(false) + .zip(ac_r.iter_groups(false)) + .map(|(l, r)| Some(apply_operator(l?.as_ref(), r?.as_ref(), self.op))) + .map(|opt_res| opt_res.transpose()) + .collect::>()? 
+ .with_name(name); ac_l.with_update_groups(UpdateGroups::WithSeriesLen); ac_l.with_agg_state(AggState::AggregatedList(ca.into_series())); @@ -399,7 +397,7 @@ mod stats { #[cfg(debug_assertions)] { - match (fld_l.data_type(), fld_r.data_type()) { + match (fld_l.dtype(), fld_r.dtype()) { #[cfg(feature = "dtype-categorical")] (DataType::String, DataType::Categorical(_, _) | DataType::Enum(_, _)) => {}, #[cfg(feature = "dtype-categorical")] diff --git a/crates/polars-expr/src/expressions/cast.rs b/crates/polars-expr/src/expressions/cast.rs index f78f1f7c11f7..dcc3d4dddbca 100644 --- a/crates/polars-expr/src/expressions/cast.rs +++ b/crates/polars-expr/src/expressions/cast.rs @@ -6,14 +6,14 @@ use crate::expressions::{AggState, AggregationContext, PartitionedAggregation, P pub struct CastExpr { pub(crate) input: Arc, - pub(crate) data_type: DataType, + pub(crate) dtype: DataType, pub(crate) expr: Expr, pub(crate) options: CastOptions, } impl CastExpr { fn finish(&self, input: &Series) -> PolarsResult { - input.cast_with_options(&self.data_type, self.options) + input.cast_with_options(&self.dtype, self.options) } } @@ -71,7 +71,7 @@ impl PhysicalExpr for CastExpr { fn to_field(&self, input_schema: &Schema) -> PolarsResult { self.input.to_field(input_schema).map(|mut fld| { - fld.coerce(self.data_type.clone()); + fld.coerce(self.dtype.clone()); fld }) } diff --git a/crates/polars-expr/src/expressions/column.rs b/crates/polars-expr/src/expressions/column.rs index cac4b52ddb11..603353e4815b 100644 --- a/crates/polars-expr/src/expressions/column.rs +++ b/crates/polars-expr/src/expressions/column.rs @@ -7,13 +7,13 @@ use super::*; use crate::expressions::{AggregationContext, PartitionedAggregation, PhysicalExpr}; pub struct ColumnExpr { - name: Arc, + name: PlSmallStr, expr: Expr, schema: Option, } impl ColumnExpr { - pub fn new(name: Arc, expr: Expr, schema: Option) -> Self { + pub fn new(name: PlSmallStr, expr: Expr, schema: Option) -> Self { Self { name, expr, schema } } } @@ -123,7 +123,7 @@ impl ColumnExpr { // Linear search will be relatively cheap as we only search the CSE columns. Ok(columns .iter() - .find(|s| s.name() == self.name.as_ref()) + .find(|s| s.name() == &self.name) .unwrap() .clone()) } diff --git a/crates/polars-expr/src/expressions/count.rs b/crates/polars-expr/src/expressions/count.rs index 246e939e3ef3..2d8fbeb6a2d2 100644 --- a/crates/polars-expr/src/expressions/count.rs +++ b/crates/polars-expr/src/expressions/count.rs @@ -22,7 +22,10 @@ impl PhysicalExpr for CountExpr { } fn evaluate(&self, df: &DataFrame, _state: &ExecutionState) -> PolarsResult { - Ok(Series::new("len", [df.height() as IdxSize])) + Ok(Series::new( + PlSmallStr::from_static("len"), + [df.height() as IdxSize], + )) } fn evaluate_on_groups<'a>( @@ -31,13 +34,13 @@ impl PhysicalExpr for CountExpr { groups: &'a GroupsProxy, _state: &ExecutionState, ) -> PolarsResult> { - let ca = groups.group_count().with_name(LEN); + let ca = groups.group_count().with_name(PlSmallStr::from_static(LEN)); let s = ca.into_series(); Ok(AggregationContext::new(s, Cow::Borrowed(groups), true)) } fn to_field(&self, _input_schema: &Schema) -> PolarsResult { - Ok(Field::new(LEN, IDX_DTYPE)) + Ok(Field::new(PlSmallStr::from_static(LEN), IDX_DTYPE)) } fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> { @@ -67,6 +70,6 @@ impl PartitionedAggregation for CountExpr { ) -> PolarsResult { // SAFETY: groups are in bounds. 
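// A rough sketch of the count-column convention used by `CountExpr` above, assuming
// the `polars_core` prelude; the helper names are illustrative. The default output
// column is called "len", and constant names are built with `PlSmallStr::from_static`,
// the constructor intended for `'static` literals.
use polars_core::prelude::*;

fn len_series(df: &DataFrame) -> Series {
    Series::new(PlSmallStr::from_static("len"), [df.height() as IdxSize])
}

fn len_field() -> Field {
    Field::new(PlSmallStr::from_static("len"), IDX_DTYPE)
}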
let agg = unsafe { partitioned.agg_sum(groups) }; - Ok(agg.with_name(LEN)) + Ok(agg.with_name(PlSmallStr::from_static(LEN))) } } diff --git a/crates/polars-expr/src/expressions/filter.rs b/crates/polars-expr/src/expressions/filter.rs index d9df88419ae7..4e02b38ae4b7 100644 --- a/crates/polars-expr/src/expressions/filter.rs +++ b/crates/polars-expr/src/expressions/filter.rs @@ -45,15 +45,20 @@ impl PhysicalExpr for FilterExpr { let (ac_s, ac_predicate) = POOL.install(|| rayon::join(ac_s_f, ac_predicate_f)); let (mut ac_s, mut ac_predicate) = (ac_s?, ac_predicate?); + // Check if the groups are still equal, otherwise aggregate. + // TODO! create a special group iters that don't materialize + if ac_s.groups.as_ref() as *const _ != ac_predicate.groups.as_ref() as *const _ { + let _ = ac_s.aggregated(); + let _ = ac_predicate.aggregated(); + } if ac_predicate.is_aggregated() || ac_s.is_aggregated() { - // SAFETY: unstable series never lives longer than the iterator. - let preds = unsafe { ac_predicate.iter_groups(false) }; + let preds = ac_predicate.iter_groups(false); let s = ac_s.aggregated(); let ca = s.list()?; let out = if ca.is_empty() { // return an empty list if ca is empty. - ListChunked::full_null_with_dtype(ca.name(), 0, ca.inner_dtype()) + ListChunked::full_null_with_dtype(ca.name().clone(), 0, ca.inner_dtype()) } else { { ca.amortized_iter() @@ -65,7 +70,7 @@ impl PhysicalExpr for FilterExpr { _ => Ok(None), }) .collect::>()? - .with_name(s.name()) + .with_name(s.name().clone()) } }; ac_s.with_series(out.into_series(), true, Some(&self.expr))?; diff --git a/crates/polars-expr/src/expressions/gather.rs b/crates/polars-expr/src/expressions/gather.rs index c54f8b9e8262..c82bedee986b 100644 --- a/crates/polars-expr/src/expressions/gather.rs +++ b/crates/polars-expr/src/expressions/gather.rs @@ -81,7 +81,7 @@ impl PhysicalExpr for GatherExpr { .map(|(s, idx)| Some(s?.as_ref().take(idx?.as_ref().idx().unwrap()))) .map(|opt_res| opt_res.transpose()) .collect::>()? 
- .with_name(ac.series().name()) + .with_name(ac.series().name().clone()) }; ac.with_series(taken.into_series(), true, Some(&self.expr))?; @@ -250,24 +250,22 @@ impl GatherExpr { &ac.dtype(), idx.series().len(), groups.len(), - ac.series().name(), + ac.series().name().clone(), )?; - unsafe { - let iter = ac.iter_groups(false).zip(idx.iter_groups(false)); - for (s, idx) in iter { - match (s, idx) { - (Some(s), Some(idx)) => { - let idx = convert_to_unsigned_index(idx.as_ref(), s.as_ref().len())?; - let out = s.as_ref().take(&idx)?; - builder.append_series(&out)?; - }, - _ => builder.append_null(), - }; - } - let out = builder.finish().into_series(); - ac.with_agg_state(AggState::AggregatedList(out)); + let iter = ac.iter_groups(false).zip(idx.iter_groups(false)); + for (s, idx) in iter { + match (s, idx) { + (Some(s), Some(idx)) => { + let idx = convert_to_unsigned_index(idx.as_ref(), s.as_ref().len())?; + let out = s.as_ref().take(&idx)?; + builder.append_series(&out)?; + }, + _ => builder.append_null(), + }; } + let out = builder.finish().into_series(); + ac.with_agg_state(AggState::AggregatedList(out)); Ok(ac) } } diff --git a/crates/polars-expr/src/expressions/group_iter.rs b/crates/polars-expr/src/expressions/group_iter.rs index 8c921a519bd1..6b1d54d0ac13 100644 --- a/crates/polars-expr/src/expressions/group_iter.rs +++ b/crates/polars-expr/src/expressions/group_iter.rs @@ -5,10 +5,7 @@ use polars_core::series::amortized_iter::AmortSeries; use super::*; impl<'a> AggregationContext<'a> { - /// # Safety - /// The lifetime of [AmortSeries] is bound to the iterator. Keeping it alive - /// longer than the iterator is UB. - pub(super) unsafe fn iter_groups( + pub(super) fn iter_groups( &mut self, keep_names: bool, ) -> Box> + '_> { @@ -16,7 +13,11 @@ impl<'a> AggregationContext<'a> { AggState::Literal(_) => { self.groups(); let s = self.series().rechunk(); - let name = if keep_names { s.name() } else { "" }; + let name = if keep_names { + s.name().clone() + } else { + PlSmallStr::EMPTY + }; // SAFETY: dtype is correct unsafe { Box::new(LitIter::new( @@ -30,7 +31,11 @@ impl<'a> AggregationContext<'a> { AggState::AggregatedScalar(_) => { self.groups(); let s = self.series(); - let name = if keep_names { s.name() } else { "" }; + let name = if keep_names { + s.name().clone() + } else { + PlSmallStr::EMPTY + }; // SAFETY: dtype is correct unsafe { Box::new(FlatIter::new( @@ -44,7 +49,11 @@ impl<'a> AggregationContext<'a> { AggState::AggregatedList(_) => { let s = self.series(); let list = s.list().unwrap(); - let name = if keep_names { s.name() } else { "" }; + let name = if keep_names { + s.name().clone() + } else { + PlSmallStr::EMPTY + }; Box::new(list.amortized_iter_with_name(name)) }, AggState::NotAggregated(_) => { @@ -52,7 +61,11 @@ impl<'a> AggregationContext<'a> { let _ = self.aggregated(); let s = self.series(); let list = s.list().unwrap(); - let name = if keep_names { s.name() } else { "" }; + let name = if keep_names { + s.name().clone() + } else { + PlSmallStr::EMPTY + }; Box::new(list.amortized_iter_with_name(name)) }, } @@ -71,7 +84,7 @@ struct LitIter { impl LitIter { /// # Safety /// Caller must ensure the given `logical` dtype belongs to `array`. 
- unsafe fn new(array: ArrayRef, len: usize, logical: &DataType, name: &str) -> Self { + unsafe fn new(array: ArrayRef, len: usize, logical: &DataType, name: PlSmallStr) -> Self { let series_container = Rc::new(Series::from_chunks_and_dtype_unchecked( name, vec![array], @@ -120,7 +133,7 @@ struct FlatIter { impl FlatIter { /// # Safety /// Caller must ensure the given `logical` dtype belongs to `array`. - unsafe fn new(chunks: &[ArrayRef], len: usize, logical: &DataType, name: &str) -> Self { + unsafe fn new(chunks: &[ArrayRef], len: usize, logical: &DataType, name: PlSmallStr) -> Self { let mut stack = Vec::with_capacity(chunks.len()); for chunk in chunks.iter().rev() { stack.push(chunk.clone()) diff --git a/crates/polars-expr/src/expressions/literal.rs b/crates/polars-expr/src/expressions/literal.rs index 6b43825087a1..6d406c5b297a 100644 --- a/crates/polars-expr/src/expressions/literal.rs +++ b/crates/polars-expr/src/expressions/literal.rs @@ -3,7 +3,7 @@ use std::ops::Deref; use polars_core::prelude::*; use polars_core::utils::NoNull; -use polars_plan::constants::LITERAL_NAME; +use polars_plan::constants::get_literal_name; use super::*; use crate::expressions::{AggregationContext, PartitionedAggregation, PhysicalExpr}; @@ -24,30 +24,26 @@ impl PhysicalExpr for LiteralExpr { use LiteralValue::*; let s = match &self.0 { #[cfg(feature = "dtype-i8")] - Int8(v) => Int8Chunked::full(LITERAL_NAME, *v, 1).into_series(), + Int8(v) => Int8Chunked::full(get_literal_name().clone(), *v, 1).into_series(), #[cfg(feature = "dtype-i16")] - Int16(v) => Int16Chunked::full(LITERAL_NAME, *v, 1).into_series(), - Int32(v) => Int32Chunked::full(LITERAL_NAME, *v, 1).into_series(), - Int64(v) => Int64Chunked::full(LITERAL_NAME, *v, 1).into_series(), + Int16(v) => Int16Chunked::full(get_literal_name().clone(), *v, 1).into_series(), + Int32(v) => Int32Chunked::full(get_literal_name().clone(), *v, 1).into_series(), + Int64(v) => Int64Chunked::full(get_literal_name().clone(), *v, 1).into_series(), #[cfg(feature = "dtype-u8")] - UInt8(v) => UInt8Chunked::full(LITERAL_NAME, *v, 1).into_series(), + UInt8(v) => UInt8Chunked::full(get_literal_name().clone(), *v, 1).into_series(), #[cfg(feature = "dtype-u16")] - UInt16(v) => UInt16Chunked::full(LITERAL_NAME, *v, 1).into_series(), - UInt32(v) => UInt32Chunked::full(LITERAL_NAME, *v, 1).into_series(), - UInt64(v) => UInt64Chunked::full(LITERAL_NAME, *v, 1).into_series(), - Float32(v) => Float32Chunked::full(LITERAL_NAME, *v, 1).into_series(), - Float64(v) => Float64Chunked::full(LITERAL_NAME, *v, 1).into_series(), + UInt16(v) => UInt16Chunked::full(get_literal_name().clone(), *v, 1).into_series(), + UInt32(v) => UInt32Chunked::full(get_literal_name().clone(), *v, 1).into_series(), + UInt64(v) => UInt64Chunked::full(get_literal_name().clone(), *v, 1).into_series(), + Float32(v) => Float32Chunked::full(get_literal_name().clone(), *v, 1).into_series(), + Float64(v) => Float64Chunked::full(get_literal_name().clone(), *v, 1).into_series(), #[cfg(feature = "dtype-decimal")] - Decimal(v, scale) => Int128Chunked::full(LITERAL_NAME, *v, 1) + Decimal(v, scale) => Int128Chunked::full(get_literal_name().clone(), *v, 1) .into_decimal_unchecked(None, *scale) .into_series(), - Boolean(v) => BooleanChunked::full(LITERAL_NAME, *v, 1).into_series(), - Null => polars_core::prelude::Series::new_null(LITERAL_NAME, 1), - Range { - low, - high, - data_type, - } => match data_type { + Boolean(v) => BooleanChunked::full(get_literal_name().clone(), *v, 1).into_series(), + Null => 
polars_core::prelude::Series::new_null(get_literal_name().clone(), 1), + Range { low, high, dtype } => match dtype { DataType::Int32 => { polars_ensure!( *low >= i32::MIN as i64 && *high <= i32::MAX as i64, @@ -78,27 +74,29 @@ impl PhysicalExpr for LiteralExpr { InvalidOperation: "datatype `{}` is not supported as range", dt ), }, - String(v) => StringChunked::full(LITERAL_NAME, v, 1).into_series(), - Binary(v) => BinaryChunked::full(LITERAL_NAME, v, 1).into_series(), + String(v) => StringChunked::full(get_literal_name().clone(), v, 1).into_series(), + Binary(v) => BinaryChunked::full(get_literal_name().clone(), v, 1).into_series(), #[cfg(feature = "dtype-datetime")] - DateTime(timestamp, tu, tz) => Int64Chunked::full(LITERAL_NAME, *timestamp, 1) - .into_datetime(*tu, tz.clone()) - .into_series(), + DateTime(timestamp, tu, tz) => { + Int64Chunked::full(get_literal_name().clone(), *timestamp, 1) + .into_datetime(*tu, tz.clone()) + .into_series() + }, #[cfg(feature = "dtype-duration")] - Duration(v, tu) => Int64Chunked::full(LITERAL_NAME, *v, 1) + Duration(v, tu) => Int64Chunked::full(get_literal_name().clone(), *v, 1) .into_duration(*tu) .into_series(), #[cfg(feature = "dtype-date")] - Date(v) => Int32Chunked::full(LITERAL_NAME, *v, 1) + Date(v) => Int32Chunked::full(get_literal_name().clone(), *v, 1) .into_date() .into_series(), #[cfg(feature = "dtype-time")] - Time(v) => Int64Chunked::full(LITERAL_NAME, *v, 1) + Time(v) => Int64Chunked::full(get_literal_name().clone(), *v, 1) .into_time() .into_series(), Series(series) => series.deref().clone(), lv @ (Int(_) | Float(_) | StrCat(_)) => polars_core::prelude::Series::from_any_values( - LITERAL_NAME, + get_literal_name().clone(), &[lv.to_any_value().unwrap()], false, ) @@ -124,7 +122,7 @@ impl PhysicalExpr for LiteralExpr { fn to_field(&self, _input_schema: &Schema) -> PolarsResult { let dtype = self.0.get_datatype(); - Ok(Field::new("literal", dtype)) + Ok(Field::new(PlSmallStr::from_static("literal"), dtype)) } fn is_literal(&self) -> bool { true diff --git a/crates/polars-expr/src/expressions/mod.rs b/crates/polars-expr/src/expressions/mod.rs index 17179f89cbdd..b66920de7ab9 100644 --- a/crates/polars-expr/src/expressions/mod.rs +++ b/crates/polars-expr/src/expressions/mod.rs @@ -421,7 +421,9 @@ impl<'a> AggregationContext<'a> { self.groups(); let rows = self.groups.len(); let s = s.new_from_index(0, rows); - s.reshape_list(&[rows as i64, -1]).unwrap() + let out = s.reshape_list(&[rows as i64, -1]).unwrap(); + self.state = AggState::AggregatedList(out.clone()); + out }, } } @@ -613,7 +615,7 @@ impl PhysicalIoExpr for PhysicalIoHelper { self.expr.evaluate(df, &state) } - fn live_variables(&self) -> Option>> { + fn live_variables(&self) -> Option> { Some(expr_to_leaf_column_names(self.expr.as_expression()?)) } diff --git a/crates/polars-expr/src/expressions/rolling.rs b/crates/polars-expr/src/expressions/rolling.rs index 614673091f07..601901460c3f 100644 --- a/crates/polars-expr/src/expressions/rolling.rs +++ b/crates/polars-expr/src/expressions/rolling.rs @@ -13,7 +13,7 @@ pub(crate) struct RollingExpr { /// A function Expr. i.e. Mean, Median, Max, etc. 
pub(crate) function: Expr, pub(crate) phys_function: Arc, - pub(crate) out_name: Option>, + pub(crate) out_name: Option, pub(crate) options: RollingGroupOptions, pub(crate) expr: Expr, } @@ -45,7 +45,7 @@ impl PhysicalExpr for RollingExpr { .finalize(); polars_ensure!(out.len() == groups.len(), agg_len = out.len(), groups.len()); if let Some(name) = &self.out_name { - out.rename(name.as_ref()); + out.rename(name.clone()); } Ok(out) } diff --git a/crates/polars-expr/src/expressions/slice.rs b/crates/polars-expr/src/expressions/slice.rs index d2bc9137a7d3..579c8d66635e 100644 --- a/crates/polars-expr/src/expressions/slice.rs +++ b/crates/polars-expr/src/expressions/slice.rs @@ -60,9 +60,16 @@ fn check_argument(arg: &Series, groups: &GroupsProxy, name: &str, expr: &Expr) - Ok(()) } -fn slice_groups_idx(offset: i64, length: usize, first: IdxSize, idx: &[IdxSize]) -> IdxItem { +fn slice_groups_idx(offset: i64, length: usize, mut first: IdxSize, idx: &[IdxSize]) -> IdxItem { let (offset, len) = slice_offsets(offset, length, idx.len()); - (first + offset as IdxSize, idx[offset..offset + len].into()) + + // If slice isn't out of bounds, we replace first. + // If slice is oob, the `idx` vec will be empty and `first` will be ignored + if let Some(f) = idx.get(offset) { + first = *f; + } + // This is a clone of the vec, which is unfortunate. Maybe we have a `sliceable` unitvec one day. + (first, idx[offset..offset + len].into()) } fn slice_groups_slice(offset: i64, length: usize, first: IdxSize, len: IdxSize) -> [IdxSize; 2] { diff --git a/crates/polars-expr/src/expressions/sortby.rs b/crates/polars-expr/src/expressions/sortby.rs index cc3447e1539c..0c2a775657d4 100644 --- a/crates/polars-expr/src/expressions/sortby.rs +++ b/crates/polars-expr/src/expressions/sortby.rs @@ -131,9 +131,9 @@ fn sort_by_groups_no_match_single<'a>( }, _ => Ok(None), }) - .collect_ca_with_dtype("", dtype) + .collect_ca_with_dtype(PlSmallStr::EMPTY, dtype) }); - let s = ca?.with_name(s_in.name()).into_series(); + let s = ca?.with_name(s_in.name().clone()).into_series(); ac_in.with_series(s, true, Some(expr))?; Ok(ac_in) } diff --git a/crates/polars-expr/src/expressions/ternary.rs b/crates/polars-expr/src/expressions/ternary.rs index b84e868efd35..ef12dcea0204 100644 --- a/crates/polars-expr/src/expressions/ternary.rs +++ b/crates/polars-expr/src/expressions/ternary.rs @@ -37,26 +37,23 @@ fn finish_as_iters<'a>( mut ac_falsy: AggregationContext<'a>, mut ac_mask: AggregationContext<'a>, ) -> PolarsResult> { - // SAFETY: unstable series never lives longer than the iterator. - let ca = unsafe { - ac_truthy - .iter_groups(false) - .zip(ac_falsy.iter_groups(false)) - .zip(ac_mask.iter_groups(false)) - .map(|((truthy, falsy), mask)| { - match (truthy, falsy, mask) { - (Some(truthy), Some(falsy), Some(mask)) => Some( - truthy - .as_ref() - .zip_with(mask.as_ref().bool()?, falsy.as_ref()), - ), - _ => None, - } - .transpose() - }) - .collect::>()? - .with_name(ac_truthy.series().name()) - }; + let ca = ac_truthy + .iter_groups(false) + .zip(ac_falsy.iter_groups(false)) + .zip(ac_mask.iter_groups(false)) + .map(|((truthy, falsy), mask)| { + match (truthy, falsy, mask) { + (Some(truthy), Some(falsy), Some(mask)) => Some( + truthy + .as_ref() + .zip_with(mask.as_ref().bool()?, falsy.as_ref()), + ), + _ => None, + } + .transpose() + }) + .collect::>()? + .with_name(ac_truthy.series().name().clone()); // Aggregation leaves only a single chunk. 
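// A minimal sketch of the `zip_with` call that the ternary (when/then/otherwise)
// evaluation above builds on, assuming polars-core is compiled with its `zip_with`
// feature (which the polars-expr manifest in this diff enables); the column names
// are illustrative.
use polars_core::prelude::*;

fn ternary_example() -> PolarsResult<()> {
    let truthy = Series::new("truthy".into(), &[1i32, 2, 3]);
    let falsy = Series::new("falsy".into(), &[10i32, 20, 30]);
    let mask = BooleanChunked::new("mask".into(), &[true, false, true]);
    // Rows where the mask is true come from `truthy`, the remaining rows from `falsy`.
    let out = truthy.zip_with(&mask, &falsy)?;
    assert_eq!(Vec::from(out.i32()?), &[Some(1), Some(20), Some(3)]);
    Ok(())
}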
let arr = ca.downcast_iter().next().unwrap(); @@ -283,12 +280,12 @@ impl PhysicalExpr for TernaryExpr { let values = out.array_ref(0); let offsets = ac_target.series().list().unwrap().offsets()?; let inner_type = out.dtype(); - let data_type = LargeListArray::default_datatype(values.data_type().clone()); + let dtype = LargeListArray::default_datatype(values.dtype().clone()); // SAFETY: offsets are correct. - let out = LargeListArray::new(data_type, offsets, values.clone(), None); + let out = LargeListArray::new(dtype, offsets, values.clone(), None); - let mut out = ListChunked::with_chunk(truthy.name(), out); + let mut out = ListChunked::with_chunk(truthy.name().clone(), out); unsafe { out.to_logical(inner_type.clone()) }; if ac_target.series().list().unwrap()._can_fast_explode() { diff --git a/crates/polars-expr/src/expressions/window.rs b/crates/polars-expr/src/expressions/window.rs index c2ccf7028b03..2ea353cc7c52 100644 --- a/crates/polars-expr/src/expressions/window.rs +++ b/crates/polars-expr/src/expressions/window.rs @@ -10,7 +10,6 @@ use polars_ops::frame::join::{default_join_ids, private_left_join_multiple_keys, use polars_ops::frame::SeriesJoin; use polars_ops::prelude::*; use polars_plan::prelude::*; -use polars_utils::format_smartstring; use polars_utils::sort::perfect_sort; use polars_utils::sync::SyncPtr; use rayon::prelude::*; @@ -22,8 +21,8 @@ pub struct WindowExpr { /// This will be used to create a smaller DataFrame to prevent taking unneeded columns by index pub(crate) group_by: Vec>, pub(crate) order_by: Option<(Arc, SortOptions)>, - pub(crate) apply_columns: Vec>, - pub(crate) out_name: Option>, + pub(crate) apply_columns: Vec, + pub(crate) out_name: Option, /// A function Expr. i.e. Mean, Median, Max, etc. pub(crate) function: Expr, pub(crate) phys_function: Arc, @@ -114,7 +113,7 @@ impl WindowExpr { // SAFETY: // we only have unique indices ranging from 0..len unsafe { perfect_sort(&POOL, &idx_mapping, &mut take_idx) }; - let idx = IdxCa::from_vec("", take_idx); + let idx = IdxCa::from_vec(PlSmallStr::EMPTY, take_idx); // SAFETY: // groups should always be in bounds. @@ -175,7 +174,7 @@ impl WindowExpr { let first = group.first(); let group = group_by_columns .iter() - .map(|s| format_smartstring!("{}", s.get(first as usize).unwrap())) + .map(|s| format!("{}", s.get(first as usize).unwrap())) .collect::>(); polars_bail!( expr = self.expr, ComputeError: @@ -407,7 +406,7 @@ impl PhysicalExpr for WindowExpr { if df.is_empty() { let field = self.phys_function.to_field(&df.schema())?; - return Ok(Series::full_null(field.name(), 0, field.data_type())); + return Ok(Series::full_null(field.name().clone(), 0, field.dtype())); } let group_by_columns = self @@ -497,11 +496,7 @@ impl PhysicalExpr for WindowExpr { }; // 2. 
create GroupBy object and apply aggregation - let apply_columns = self - .apply_columns - .iter() - .map(|s| s.as_ref().to_string()) - .collect(); + let apply_columns = self.apply_columns.clone(); // some window expressions need sorted groups // to make sure that the caches align we sort @@ -526,7 +521,7 @@ impl PhysicalExpr for WindowExpr { let mut out = ac.flat_naive().into_owned(); cache_gb(gb, state, &cache_key); if let Some(name) = &self.out_name { - out.rename(name.as_ref()); + out.rename(name.clone()); } Ok(out) }, @@ -534,7 +529,7 @@ impl PhysicalExpr for WindowExpr { let mut out = ac.aggregated().explode()?; cache_gb(gb, state, &cache_key); if let Some(name) = &self.out_name { - out.rename(name.as_ref()); + out.rename(name.clone()); } Ok(out) }, @@ -616,7 +611,7 @@ impl PhysicalExpr for WindowExpr { let mut out = materialize_column(&join_opt_ids, &out_column); if let Some(name) = &self.out_name { - out.rename(name.as_ref()); + out.rename(name.clone()); } if state.cache_window() { @@ -747,7 +742,7 @@ where // SAFETY: we have written all slots unsafe { values.set_len(len) } - ChunkedArray::new_vec(ca.name(), values).into_series() + ChunkedArray::new_vec(ca.name().clone(), values).into_series() } else { // We don't use a mutable bitmap as bits will have have race conditions! // A single byte might alias if we write from single threads. @@ -825,6 +820,6 @@ where values.into(), Some(validity), ); - Series::try_from((ca.name(), arr.boxed())).unwrap() + Series::try_from((ca.name().clone(), arr.boxed())).unwrap() } } diff --git a/crates/polars-expr/src/planner.rs b/crates/polars-expr/src/planner.rs index 8e648a77f3d7..315fc123d158 100644 --- a/crates/polars-expr/src/planner.rs +++ b/crates/polars-expr/src/planner.rs @@ -251,9 +251,9 @@ fn create_physical_expr_inner( if apply_columns.is_empty() { if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Literal(_))) { - apply_columns.push(Arc::from("literal")) + apply_columns.push(PlSmallStr::from_static("literal")) } else if has_aexpr(function, expr_arena, |e| matches!(e, AExpr::Len)) { - apply_columns.push(Arc::from("len")) + apply_columns.push(PlSmallStr::from_static("len")) } else { let e = node_to_expr(function, expr_arena); polars_bail!( @@ -428,13 +428,13 @@ fn create_physical_expr_inner( }, Cast { expr, - data_type, + dtype, options, } => { let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?; Ok(Arc::new(CastExpr { input: phys_expr, - data_type: data_type.clone(), + dtype: dtype.clone(), expr: node_to_expr(expression, expr_arena), options: *options, })) @@ -576,11 +576,5 @@ fn create_physical_expr_inner( node_to_expr(*input, expr_arena), ))) }, - Wildcard => { - polars_bail!(ComputeError: "wildcard column selection not supported at this point") - }, - Nth(n) => { - polars_bail!(ComputeError: "nth column selection not supported at this point (n={})", n) - }, } } diff --git a/crates/polars-expr/src/reduce/convert.rs b/crates/polars-expr/src/reduce/convert.rs index f5a33aca1a0b..279d77d6eb67 100644 --- a/crates/polars-expr/src/reduce/convert.rs +++ b/crates/polars-expr/src/reduce/convert.rs @@ -2,53 +2,48 @@ use polars_core::error::feature_gated; use polars_plan::prelude::*; use polars_utils::arena::{Arena, Node}; -use super::extrema::*; +use super::len::LenReduce; +use super::mean::MeanReduce; +use super::min_max::{MaxReduce, MinReduce}; +#[cfg(feature = "propagate_nans")] +use super::nan_min_max::{NanMaxReduce, NanMinReduce}; use super::sum::SumReduce; use super::*; -use 
crate::reduce::mean::MeanReduce; - -pub fn can_convert_into_reduction(node: Node, expr_arena: &Arena) -> bool { - match expr_arena.get(node) { - AExpr::Agg(agg) => matches!( - agg, - IRAggExpr::Min { .. } - | IRAggExpr::Max { .. } - | IRAggExpr::Mean { .. } - | IRAggExpr::Sum(_) - ), - _ => false, - } -} +/// Converts a node into a reduction + its associated selector expression. pub fn into_reduction( node: Node, - expr_arena: &Arena, + expr_arena: &mut Arena, schema: &Schema, -) -> PolarsResult, Node)>> { - let e = expr_arena.get(node); - let field = e.to_field(schema, Context::Default, expr_arena)?; +) -> PolarsResult<(Box, Node)> { + let get_dt = |node| { + expr_arena + .get(node) + .to_dtype(schema, Context::Default, expr_arena) + }; let out = match expr_arena.get(node) { AExpr::Agg(agg) => match agg { - IRAggExpr::Sum(node) => ( - Box::new(SumReduce::new(field.dtype.clone())) as Box, - *node, + IRAggExpr::Sum(input) => ( + Box::new(SumReduce::new(get_dt(*input)?)) as Box, + *input, ), IRAggExpr::Min { propagate_nans, input, } => { - if *propagate_nans && field.dtype.is_float() { + let dt = get_dt(*input)?; + if *propagate_nans && dt.is_float() { feature_gated!("propagate_nans", { - let out: Box = match field.dtype { - DataType::Float32 => Box::new(MinNanReduce::::new()), - DataType::Float64 => Box::new(MinNanReduce::::new()), + let out: Box = match dt { + DataType::Float32 => Box::new(NanMinReduce::::new()), + DataType::Float64 => Box::new(NanMinReduce::::new()), _ => unreachable!(), }; (out, *input) }) } else { ( - Box::new(MinReduce::new(field.dtype.clone())) as Box, + Box::new(MinReduce::new(dt.clone())) as Box, *input, ) } @@ -57,26 +52,38 @@ pub fn into_reduction( propagate_nans, input, } => { - if *propagate_nans && field.dtype.is_float() { + let dt = get_dt(*input)?; + if *propagate_nans && dt.is_float() { feature_gated!("propagate_nans", { - let out: Box = match field.dtype { - DataType::Float32 => Box::new(MaxNanReduce::::new()), - DataType::Float64 => Box::new(MaxNanReduce::::new()), + let out: Box = match dt { + DataType::Float32 => Box::new(NanMaxReduce::::new()), + DataType::Float64 => Box::new(NanMaxReduce::::new()), _ => unreachable!(), }; (out, *input) }) } else { - (Box::new(MaxReduce::new(field.dtype.clone())) as _, *input) + (Box::new(MaxReduce::new(dt.clone())) as _, *input) } }, IRAggExpr::Mean(input) => { - let out: Box = Box::new(MeanReduce::new(field.dtype.clone())); + let out: Box = Box::new(MeanReduce::new(get_dt(*input)?)); (out, *input) }, - _ => return Ok(None), + _ => unreachable!(), + }, + AExpr::Len => { + // Compute length on the first column, or if none exist we'll never + // be called and correctly return 0 as length anyway. 
+ let out: Box = Box::new(LenReduce::new()); + let expr = if let Some(first_column) = schema.iter_names().next() { + expr_arena.add(AExpr::Column(first_column.as_str().into())) + } else { + expr_arena.add(AExpr::Literal(LiteralValue::Null)) + }; + (out, expr) }, - _ => return Ok(None), + _ => unreachable!(), }; - Ok(Some(out)) + Ok(out) } diff --git a/crates/polars-expr/src/reduce/extrema.rs b/crates/polars-expr/src/reduce/extrema.rs deleted file mode 100644 index 5eee559e1588..000000000000 --- a/crates/polars-expr/src/reduce/extrema.rs +++ /dev/null @@ -1,249 +0,0 @@ -#[cfg(feature = "propagate_nans")] -use polars_core::datatypes::PolarsFloatType; -#[cfg(feature = "propagate_nans")] -use polars_ops::prelude::nan_propagating_aggregate; -#[cfg(feature = "propagate_nans")] -use polars_utils::min_max::MinMax; - -use super::*; - -#[derive(Clone)] -pub(super) struct MinReduce { - dtype: DataType, - value: Option, -} - -impl MinReduce { - pub(super) fn new(dtype: DataType) -> Self { - Self { dtype, value: None } - } - - fn update_impl(&mut self, other: &AnyValue<'static>) { - if let Some(value) = &mut self.value { - if other < value.value() { - value.update(other.clone()); - } - } else { - self.value = Some(Scalar::new(self.dtype.clone(), other.clone())) - } - } -} - -impl Reduction for MinReduce { - fn init_dyn(&self) -> Box { - Box::new(Self::new(self.dtype.clone())) - } - - fn reset(&mut self) { - *self = Self::new(self.dtype.clone()); - } - - fn update(&mut self, batch: &Series) -> PolarsResult<()> { - let sc = batch.min_reduce()?; - self.update_impl(sc.value()); - Ok(()) - } - - fn combine(&mut self, other: &dyn Reduction) -> PolarsResult<()> { - let other = other.as_any().downcast_ref::().unwrap(); - if let Some(value) = &other.value { - self.update_impl(value.value()); - } - Ok(()) - } - - fn finalize(&mut self) -> PolarsResult { - if let Some(value) = self.value.take() { - Ok(value) - } else { - Ok(Scalar::new(self.dtype.clone(), AnyValue::Null)) - } - } - - fn as_any(&self) -> &dyn Any { - self - } -} -#[derive(Clone)] -pub(super) struct MaxReduce { - dtype: DataType, - value: Option, -} - -impl MaxReduce { - pub(super) fn new(dtype: DataType) -> Self { - Self { dtype, value: None } - } - fn update_impl(&mut self, other: &AnyValue<'static>) { - if let Some(value) = &mut self.value { - if other > value.value() { - value.update(other.clone()); - } - } else { - self.value = Some(Scalar::new(self.dtype.clone(), other.clone())) - } - } -} - -impl Reduction for MaxReduce { - fn init_dyn(&self) -> Box { - Box::new(Self::new(self.dtype.clone())) - } - fn reset(&mut self) { - *self = Self::new(self.dtype.clone()); - } - - fn update(&mut self, batch: &Series) -> PolarsResult<()> { - let sc = batch.max_reduce()?; - self.update_impl(sc.value()); - Ok(()) - } - - fn combine(&mut self, other: &dyn Reduction) -> PolarsResult<()> { - let other = other.as_any().downcast_ref::().unwrap(); - - if let Some(value) = &other.value { - self.update_impl(value.value()); - } - Ok(()) - } - - fn finalize(&mut self) -> PolarsResult { - if let Some(value) = self.value.take() { - Ok(value) - } else { - Ok(Scalar::new(self.dtype.clone(), AnyValue::Null)) - } - } - - fn as_any(&self) -> &dyn Any { - self - } -} - -#[cfg(feature = "propagate_nans")] -#[derive(Clone)] -pub(super) struct MaxNanReduce { - value: Option, -} - -#[cfg(feature = "propagate_nans")] -impl MaxNanReduce -where - T::Native: MinMax, -{ - pub(super) fn new() -> Self { - Self { value: None } - } - fn update_impl(&mut self, other: T::Native) { - if let 
Some(value) = self.value { - self.value = Some(MinMax::max_propagate_nan(value, other)); - } else { - self.value = Some(other); - } - } -} - -#[cfg(feature = "propagate_nans")] -impl Reduction for MaxNanReduce -where - T::Native: MinMax, -{ - fn init_dyn(&self) -> Box { - Box::new(Self::new()) - } - fn reset(&mut self) { - self.value = None; - } - - fn update(&mut self, batch: &Series) -> PolarsResult<()> { - if let Some(v) = nan_propagating_aggregate::ca_nan_agg( - batch.unpack::().unwrap(), - MinMax::max_propagate_nan, - ) { - self.update_impl(v) - } - Ok(()) - } - - fn combine(&mut self, other: &dyn Reduction) -> PolarsResult<()> { - let other = other.as_any().downcast_ref::().unwrap(); - - if let Some(value) = &other.value { - self.update_impl(*value); - } - Ok(()) - } - - fn finalize(&mut self) -> PolarsResult { - let av = AnyValue::from(self.value); - Ok(Scalar::new(T::get_dtype(), av)) - } - - fn as_any(&self) -> &dyn Any { - self - } -} -#[cfg(feature = "propagate_nans")] -#[derive(Clone)] -pub(super) struct MinNanReduce { - value: Option, -} - -#[cfg(feature = "propagate_nans")] -impl crate::reduce::extrema::MinNanReduce -where - T::Native: MinMax, -{ - pub(super) fn new() -> Self { - Self { value: None } - } - fn update_impl(&mut self, other: T::Native) { - if let Some(value) = self.value { - self.value = Some(MinMax::min_propagate_nan(value, other)); - } else { - self.value = Some(other); - } - } -} - -#[cfg(feature = "propagate_nans")] -impl Reduction for crate::reduce::extrema::MinNanReduce -where - T::Native: MinMax, -{ - fn init_dyn(&self) -> Box { - Box::new(Self::new()) - } - fn reset(&mut self) { - self.value = None; - } - - fn update(&mut self, batch: &Series) -> PolarsResult<()> { - if let Some(v) = nan_propagating_aggregate::ca_nan_agg( - batch.unpack::().unwrap(), - MinMax::min_propagate_nan, - ) { - self.update_impl(v) - } - Ok(()) - } - - fn combine(&mut self, other: &dyn Reduction) -> PolarsResult<()> { - let other = other.as_any().downcast_ref::().unwrap(); - - if let Some(value) = &other.value { - self.update_impl(*value); - } - Ok(()) - } - - fn finalize(&mut self) -> PolarsResult { - let av = AnyValue::from(self.value); - Ok(Scalar::new(T::get_dtype(), av)) - } - - fn as_any(&self) -> &dyn Any { - self - } -} diff --git a/crates/polars-expr/src/reduce/len.rs b/crates/polars-expr/src/reduce/len.rs new file mode 100644 index 000000000000..1e11a505410d --- /dev/null +++ b/crates/polars-expr/src/reduce/len.rs @@ -0,0 +1,45 @@ +use polars_core::error::constants::LENGTH_LIMIT_MSG; + +use super::*; + +#[derive(Clone)] +pub struct LenReduce {} + +impl LenReduce { + pub fn new() -> Self { + Self {} + } +} + +impl Reduction for LenReduce { + fn new_reducer(&self) -> Box { + Box::new(LenReduceState { len: 0 }) + } +} + +pub struct LenReduceState { + len: u64, +} + +impl ReductionState for LenReduceState { + fn update(&mut self, batch: &Series) -> PolarsResult<()> { + self.len += batch.len() as u64; + Ok(()) + } + + fn combine(&mut self, other: &dyn ReductionState) -> PolarsResult<()> { + let other = other.as_any().downcast_ref::().unwrap(); + self.len += other.len; + Ok(()) + } + + fn finalize(&self) -> PolarsResult { + #[allow(clippy::useless_conversion)] + let as_idx: IdxSize = self.len.try_into().expect(LENGTH_LIMIT_MSG); + Ok(Scalar::new(IDX_DTYPE, as_idx.into())) + } + + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-expr/src/reduce/mean.rs b/crates/polars-expr/src/reduce/mean.rs index 0d06974d956b..e8b19b342de6 100644 --- 
a/crates/polars-expr/src/reduce/mean.rs +++ b/crates/polars-expr/src/reduce/mean.rs @@ -1,67 +1,51 @@ -use polars_core::utils::Container; - use super::*; #[derive(Clone)] pub struct MeanReduce { - value: Option, - len: u64, dtype: DataType, } impl MeanReduce { - pub(crate) fn new(dtype: DataType) -> Self { - let value = None; - Self { - value, - len: 0, - dtype, - } - } - - fn update_impl(&mut self, value: &AnyValue<'static>, len: u64) { - let value = value.extract::().expect("phys numeric"); - if let Some(acc) = &mut self.value { - *acc += value; - self.len += len; - } else { - self.value = Some(value); - self.len = len; - } + pub fn new(dtype: DataType) -> Self { + Self { dtype } } } impl Reduction for MeanReduce { - fn init_dyn(&self) -> Box { - Box::new(Self::new(self.dtype.clone())) - } - fn reset(&mut self) { - self.value = None; - self.len = 0; + fn new_reducer(&self) -> Box { + Box::new(MeanReduceState { + dtype: self.dtype.clone(), + sum: 0.0, + count: 0, + }) } +} +pub struct MeanReduceState { + dtype: DataType, + sum: f64, + count: u64, +} + +impl ReductionState for MeanReduceState { fn update(&mut self, batch: &Series) -> PolarsResult<()> { - let sc = batch.sum_reduce()?; - self.update_impl(sc.value(), batch.len() as u64); + let count = batch.len() as u64 - batch.null_count() as u64; + self.count += count; + self.sum += batch._sum_as_f64(); Ok(()) } - fn combine(&mut self, other: &dyn Reduction) -> PolarsResult<()> { + fn combine(&mut self, other: &dyn ReductionState) -> PolarsResult<()> { let other = other.as_any().downcast_ref::().unwrap(); - - match (self.value, other.value) { - (Some(l), Some(r)) => self.value = Some(l + r), - (None, Some(r)) => self.value = Some(r), - (Some(l), None) => self.value = Some(l), - (None, None) => self.value = None, - } - self.len += other.len; + self.sum += other.sum; + self.count += other.count; Ok(()) } - fn finalize(&mut self) -> PolarsResult { + fn finalize(&self) -> PolarsResult { + let val = (self.count > 0).then(|| self.sum / self.count as f64); Ok(polars_core::scalar::reduce::mean_reduce( - self.value.map(|v| v / self.len as f64), + val, self.dtype.clone(), )) } diff --git a/crates/polars-expr/src/reduce/min_max.rs b/crates/polars-expr/src/reduce/min_max.rs new file mode 100644 index 000000000000..ba011d7d95f0 --- /dev/null +++ b/crates/polars-expr/src/reduce/min_max.rs @@ -0,0 +1,111 @@ +use super::*; + +#[derive(Clone)] +pub struct MinReduce { + dtype: DataType, +} + +impl MinReduce { + pub fn new(dtype: DataType) -> Self { + Self { dtype } + } +} + +impl Reduction for MinReduce { + fn new_reducer(&self) -> Box { + Box::new(MinReduceState { + value: Scalar::new(self.dtype.clone(), AnyValue::Null), + }) + } +} + +struct MinReduceState { + value: Scalar, +} + +impl MinReduceState { + fn update_with_value(&mut self, other: &AnyValue<'static>) { + if self.value.is_null() + || !other.is_null() && (other < self.value.value() || self.value.is_nan()) + { + self.value.update(other.clone()); + } + } +} + +impl ReductionState for MinReduceState { + fn update(&mut self, batch: &Series) -> PolarsResult<()> { + let sc = batch.min_reduce()?; + self.update_with_value(sc.value()); + Ok(()) + } + + fn combine(&mut self, other: &dyn ReductionState) -> PolarsResult<()> { + let other = other.as_any().downcast_ref::().unwrap(); + self.update_with_value(other.value.value()); + Ok(()) + } + + fn finalize(&self) -> PolarsResult { + Ok(self.value.clone()) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +#[derive(Clone)] +pub struct MaxReduce { + dtype: 
DataType, +} + +impl MaxReduce { + pub fn new(dtype: DataType) -> Self { + Self { dtype } + } +} + +impl Reduction for MaxReduce { + fn new_reducer(&self) -> Box { + Box::new(MaxReduceState { + value: Scalar::new(self.dtype.clone(), AnyValue::Null), + }) + } +} + +struct MaxReduceState { + value: Scalar, +} + +impl MaxReduceState { + fn update_with_value(&mut self, other: &AnyValue<'static>) { + if self.value.is_null() + || !other.is_null() && (other > self.value.value() || self.value.is_nan()) + { + self.value.update(other.clone()); + } + } +} + +impl ReductionState for MaxReduceState { + fn update(&mut self, batch: &Series) -> PolarsResult<()> { + let sc = batch.max_reduce()?; + self.update_with_value(sc.value()); + Ok(()) + } + + fn combine(&mut self, other: &dyn ReductionState) -> PolarsResult<()> { + let other = other.as_any().downcast_ref::().unwrap(); + self.update_with_value(other.value.value()); + Ok(()) + } + + fn finalize(&self) -> PolarsResult { + Ok(self.value.clone()) + } + + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-expr/src/reduce/mod.rs b/crates/polars-expr/src/reduce/mod.rs index bb51ba5c8a8d..26f9749b4479 100644 --- a/crates/polars-expr/src/reduce/mod.rs +++ b/crates/polars-expr/src/reduce/mod.rs @@ -1,23 +1,27 @@ mod convert; -mod extrema; +mod len; mod mean; +mod min_max; +#[cfg(feature = "propagate_nans")] +mod nan_min_max; mod sum; use std::any::Any; -pub use convert::{can_convert_into_reduction, into_reduction}; +pub use convert::into_reduction; use polars_core::prelude::*; -#[allow(dead_code)] -pub trait Reduction: Any + Send { - // Creates a fresh reduction. - fn init_dyn(&self) -> Box; - - // Resets this reduction to the fresh initial state. - fn reset(&mut self); +pub trait Reduction: Send { + /// Create a new reducer for this Reduction. + fn new_reducer(&self) -> Box; +} +pub trait ReductionState: Any + Send { + /// Adds the given series into the reduction. fn update(&mut self, batch: &Series) -> PolarsResult<()>; + /// Adds the elements of the given series at the given indices into the reduction. + /// /// # Safety /// Implementations may elide bound checks. unsafe fn update_gathered(&mut self, batch: &Series, idx: &[IdxSize]) -> PolarsResult<()> { @@ -25,9 +29,12 @@ pub trait Reduction: Any + Send { self.update(&batch) } - fn combine(&mut self, other: &dyn Reduction) -> PolarsResult<()>; + /// Combines this reduction with another. + fn combine(&mut self, other: &dyn ReductionState) -> PolarsResult<()>; - fn finalize(&mut self) -> PolarsResult; + /// Returns a final result from the reduction. + fn finalize(&self) -> PolarsResult; + /// Returns this ReductionState as a dyn Any.
fn as_any(&self) -> &dyn Any; } diff --git a/crates/polars-expr/src/reduce/nan_min_max.rs b/crates/polars-expr/src/reduce/nan_min_max.rs new file mode 100644 index 000000000000..4a42ce37d3a5 --- /dev/null +++ b/crates/polars-expr/src/reduce/nan_min_max.rs @@ -0,0 +1,141 @@ +use std::marker::PhantomData; + +use polars_compute::min_max::MinMaxKernel; +use polars_core::datatypes::PolarsFloatType; +use polars_utils::min_max::MinMax; + +use super::*; + +#[derive(Clone)] +pub struct NanMinReduce { + _phantom: PhantomData, +} + +impl NanMinReduce { + pub fn new() -> Self { + Self { + _phantom: PhantomData, + } + } +} + +impl Reduction for NanMinReduce +where + F::Array: for<'a> MinMaxKernel = F::Native>, +{ + fn new_reducer(&self) -> Box { + Box::new(NanMinReduceState:: { value: None }) + } +} + +struct NanMinReduceState { + value: Option, +} + +impl NanMinReduceState { + fn update_with_value(&mut self, other: Option) { + if let Some(other) = other { + if let Some(value) = self.value { + self.value = Some(MinMax::min_propagate_nan(value, other)); + } else { + self.value = Some(other); + } + } + } +} + +impl ReductionState for NanMinReduceState +where + F::Array: for<'a> MinMaxKernel = F::Native>, +{ + fn update(&mut self, batch: &Series) -> PolarsResult<()> { + let ca = batch.unpack::().unwrap(); + let reduced = ca + .downcast_iter() + .filter_map(MinMaxKernel::min_propagate_nan_kernel) + .reduce(MinMax::min_propagate_nan); + self.update_with_value(reduced); + Ok(()) + } + + fn combine(&mut self, other: &dyn ReductionState) -> PolarsResult<()> { + let other = other.as_any().downcast_ref::().unwrap(); + self.update_with_value(other.value); + Ok(()) + } + + fn finalize(&self) -> PolarsResult { + Ok(Scalar::new(F::get_dtype(), AnyValue::from(self.value))) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +#[derive(Clone)] +pub struct NanMaxReduce { + _phantom: PhantomData, +} + +impl NanMaxReduce { + pub fn new() -> Self { + Self { + _phantom: PhantomData, + } + } +} + +impl Reduction for NanMaxReduce +where + F::Array: for<'a> MinMaxKernel = F::Native>, +{ + fn new_reducer(&self) -> Box { + Box::new(NanMaxReduceState:: { value: None }) + } +} + +struct NanMaxReduceState { + value: Option, +} + +impl NanMaxReduceState { + fn update_with_value(&mut self, other: Option) { + if let Some(other) = other { + if let Some(value) = self.value { + self.value = Some(MinMax::max_propagate_nan(value, other)); + } else { + self.value = Some(other); + } + } + } +} + +impl ReductionState for NanMaxReduceState +where + F::Array: for<'a> MinMaxKernel = F::Native>, +{ + fn update(&mut self, batch: &Series) -> PolarsResult<()> { + let ca = batch.unpack::().unwrap(); + let reduced = ca + .downcast_iter() + .filter_map(MinMaxKernel::max_propagate_nan_kernel) + .reduce(MinMax::max_propagate_nan); + self.update_with_value(reduced); + Ok(()) + } + + fn combine(&mut self, other: &dyn ReductionState) -> PolarsResult<()> { + let other = other.as_any().downcast_ref::().unwrap(); + self.update_with_value(other.value); + Ok(()) + } + + fn finalize(&self) -> PolarsResult { + Ok(Scalar::new(F::get_dtype(), AnyValue::from(self.value))) + } + + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/crates/polars-expr/src/reduce/sum.rs b/crates/polars-expr/src/reduce/sum.rs index 9e1e0e4600e4..0f1d094ded3f 100644 --- a/crates/polars-expr/src/reduce/sum.rs +++ b/crates/polars-expr/src/reduce/sum.rs @@ -4,42 +4,54 @@ use super::*; #[derive(Clone)] pub struct SumReduce { - value: Scalar, + dtype: DataType, } impl SumReduce { - 
pub(crate) fn new(dtype: DataType) -> Self { - let value = Scalar::new(dtype, AnyValue::Null); - Self { value } - } - - fn update_impl(&mut self, value: &AnyValue<'static>) { - self.value.update(self.value.value().add(value)) + pub fn new(dtype: DataType) -> Self { + // We cast small dtypes up in the sum, we must also do this when + // returning the empty sum to be consistent. + use DataType::*; + let dtype = match dtype { + Boolean => IDX_DTYPE, + Int8 | UInt8 | Int16 | UInt16 => Int64, + dt => dt, + }; + Self { dtype } } } impl Reduction for SumReduce { - fn init_dyn(&self) -> Box { - Box::new(Self::new(self.value.dtype().clone())) + fn new_reducer(&self) -> Box { + let value = Scalar::new(self.dtype.clone(), AnyValue::zero_sum(&self.dtype)); + Box::new(SumReduceState { value }) } - fn reset(&mut self) { - let av = AnyValue::zero(self.value.dtype()); - self.value.update(av); +} + +struct SumReduceState { + value: Scalar, +} + +impl SumReduceState { + fn add_value(&mut self, other: &AnyValue<'_>) { + self.value.update(self.value.value().add(other)); } +} +impl ReductionState for SumReduceState { fn update(&mut self, batch: &Series) -> PolarsResult<()> { - let sc = batch.sum_reduce()?; - self.update_impl(sc.value()); + let reduced = batch.sum_reduce()?; + self.add_value(reduced.value()); Ok(()) } - fn combine(&mut self, other: &dyn Reduction) -> PolarsResult<()> { + fn combine(&mut self, other: &dyn ReductionState) -> PolarsResult<()> { let other = other.as_any().downcast_ref::().unwrap(); - self.update_impl(other.value.value()); + self.add_value(other.value.value()); Ok(()) } - fn finalize(&mut self) -> PolarsResult { + fn finalize(&self) -> PolarsResult { Ok(self.value.clone()) } diff --git a/crates/polars-expr/src/state/node_timer.rs b/crates/polars-expr/src/state/node_timer.rs index 95084eeb4fcb..8102aa8fcf83 100644 --- a/crates/polars-expr/src/state/node_timer.rs +++ b/crates/polars-expr/src/state/node_timer.rs @@ -42,20 +42,20 @@ impl NodeTimer { polars_ensure!(!ticks.is_empty(), ComputeError: "no data to time"); let start = ticks[0].0; ticks.push((self.query_start, start)); - let nodes_s = Series::new("node", nodes); + let nodes_s = Series::new(PlSmallStr::from_static("node"), nodes); let start: NoNull = ticks .iter() .map(|(start, _)| (start.duration_since(self.query_start)).as_micros() as u64) .collect(); let mut start = start.into_inner(); - start.rename("start"); + start.rename(PlSmallStr::from_static("start")); let end: NoNull = ticks .iter() .map(|(_, end)| (end.duration_since(self.query_start)).as_micros() as u64) .collect(); let mut end = end.into_inner(); - end.rename("end"); + end.rename(PlSmallStr::from_static("end")); let columns = vec![nodes_s, start.into_series(), end.into_series()]; let df = unsafe { DataFrame::new_no_checks(columns) }; diff --git a/crates/polars-ffi/src/lib.rs b/crates/polars-ffi/src/lib.rs index 51635b2c0068..31a47aef3eef 100644 --- a/crates/polars-ffi/src/lib.rs +++ b/crates/polars-ffi/src/lib.rs @@ -29,6 +29,6 @@ unsafe fn import_array( schema: &ffi::ArrowSchema, ) -> PolarsResult { let field = ffi::import_field_from_c(schema)?; - let out = ffi::import_array_from_c(array, field.data_type)?; + let out = ffi::import_array_from_c(array, field.dtype)?; Ok(out) } diff --git a/crates/polars-ffi/src/version_0.rs b/crates/polars-ffi/src/version_0.rs index eb24542f0733..0fc29055f66d 100644 --- a/crates/polars-ffi/src/version_0.rs +++ b/crates/polars-ffi/src/version_0.rs @@ -54,7 +54,11 @@ unsafe extern "C" fn c_release_series_export(e: *mut SeriesExport) { 
} pub fn export_series(s: &Series) -> SeriesExport { - let field = ArrowField::new(s.name(), s.dtype().to_arrow(CompatLevel::newest()), true); + let field = ArrowField::new( + s.name().clone(), + s.dtype().to_arrow(CompatLevel::newest()), + true, + ); let schema = Box::new(ffi::export_field_to_c(&field)); let mut arrays = (0..s.chunks().len()) @@ -91,7 +95,7 @@ pub unsafe fn import_series(e: SeriesExport) -> PolarsResult { }) .collect::>>()?; - Series::try_from((field.name.as_str(), chunks)) + Series::try_from((field.name.clone(), chunks)) } /// # Safety @@ -144,7 +148,7 @@ mod test { #[test] fn test_ffi() { - let s = Series::new("a", [1, 2]); + let s = Series::new("a".into(), [1, 2]); let e = export_series(&s); unsafe { diff --git a/crates/polars-io/Cargo.toml b/crates/polars-io/Cargo.toml index 84a113d96433..64259f78ad09 100644 --- a/crates/polars-io/Cargo.toml +++ b/crates/polars-io/Cargo.toml @@ -13,6 +13,7 @@ polars-core = { workspace = true } polars-error = { workspace = true } polars-json = { workspace = true, optional = true } polars-parquet = { workspace = true, optional = true } +polars-schema = { workspace = true } polars-time = { workspace = true, features = [], optional = true } polars-utils = { workspace = true, features = ['mmap'] } @@ -28,6 +29,7 @@ fast-float = { workspace = true, optional = true } flate2 = { workspace = true, optional = true } futures = { workspace = true, optional = true } glob = { version = "0.3" } +hashbrown = { workspace = true } itoa = { workspace = true, optional = true } memchr = { workspace = true } memmap = { workspace = true } @@ -40,10 +42,9 @@ regex = { workspace = true } reqwest = { workspace = true, optional = true } ryu = { workspace = true, optional = true } serde = { workspace = true, features = ["rc"], optional = true } -serde_json = { version = "1", default-features = false, features = ["alloc", "raw_value", "std"], optional = true } +serde_json = { version = "1", optional = true } simd-json = { workspace = true, optional = true } simdutf8 = { workspace = true, optional = true } -smartstring = { workspace = true } tokio = { workspace = true, features = ["fs", "net", "rt-multi-thread", "time", "sync"], optional = true } tokio-util = { workspace = true, features = ["io", "io-util"], optional = true } url = { workspace = true, optional = true } @@ -63,11 +64,10 @@ json = [ "polars-json", "simd-json", "atoi_simd", - "serde_json", "dtype-struct", "csv", ] -serde = ["dep:serde", "polars-core/serde-lazy", "polars-parquet/serde"] +serde = ["dep:serde", "polars-core/serde-lazy", "polars-parquet/serde", "polars-utils/serde"] # support for arrows ipc file parsing ipc = ["arrow/io_ipc", "arrow/io_ipc_compression"] # support for arrows streaming ipc file parsing @@ -101,7 +101,7 @@ dtype-struct = ["polars-core/dtype-struct"] dtype-decimal = ["polars-core/dtype-decimal", "polars-json?/dtype-decimal"] fmt = ["polars-core/fmt"] lazy = [] -parquet = ["polars-parquet", "polars-parquet/compression"] +parquet = ["polars-parquet", "polars-parquet/compression", "polars-core/partition_by"] async = [ "async-trait", "futures", @@ -122,12 +122,11 @@ cloud = [ "reqwest", "http", ] -file_cache = ["async", "dep:blake3", "dep:fs4"] +file_cache = ["async", "dep:blake3", "dep:fs4", "serde_json", "cloud"] aws = ["object_store/aws", "cloud", "reqwest"] azure = ["object_store/azure", "cloud"] gcp = ["object_store/gcp", "cloud"] http = ["object_store/http", "cloud"] -partition = ["polars-core/partition_by"] temporal = ["dtype-datetime", "dtype-date", "dtype-time"] simd = 
[] python = ["polars-error/python"] diff --git a/crates/polars-io/src/avro/read.rs b/crates/polars-io/src/avro/read.rs index 6d410c74ef0e..e0823e6dd916 100644 --- a/crates/polars-io/src/avro/read.rs +++ b/crates/polars-io/src/avro/read.rs @@ -39,7 +39,7 @@ impl AvroReader { /// Get schema of the Avro File pub fn schema(&mut self) -> PolarsResult { let schema = self.arrow_schema()?; - Ok(Schema::from_iter(&schema.fields)) + Ok(Schema::from_arrow_schema(&schema)) } /// Get arrow schema of the avro File, this is faster than a polars schema. @@ -109,7 +109,7 @@ where } let (projection, projected_schema) = if let Some(projection) = self.projection { - let mut prj = vec![false; schema.fields.len()]; + let mut prj = vec![false; schema.len()]; for &index in projection.iter() { prj[index] = true; } @@ -118,8 +118,7 @@ where (None, schema.clone()) }; - let avro_reader = - avro::read::Reader::new(&mut self.reader, metadata, schema.fields, projection); + let avro_reader = avro::read::Reader::new(&mut self.reader, metadata, schema, projection); finish_reader( avro_reader, diff --git a/crates/polars-io/src/cloud/adaptors.rs b/crates/polars-io/src/cloud/adaptors.rs index 435d703f2d80..5e034b55a80c 100644 --- a/crates/polars-io/src/cloud/adaptors.rs +++ b/crates/polars-io/src/cloud/adaptors.rs @@ -11,11 +11,13 @@ use tokio::io::AsyncWriteExt; use super::CloudOptions; use crate::pl_async::get_runtime; -/// Adaptor which wraps the interface of [ObjectStore::BufWriter](https://docs.rs/object_store/latest/object_store/buffered/struct.BufWriter.html) -/// exposing a synchronous interface which implements `std::io::Write`. +/// Adaptor which wraps the interface of [ObjectStore::BufWriter] exposing a synchronous interface +/// which implements `std::io::Write`. /// /// This allows it to be used in sync code which would otherwise write to a simple File or byte stream, /// such as with `polars::prelude::CsvWriter`. 
+/// +/// [ObjectStore::BufWriter]: https://docs.rs/object_store/latest/object_store/buffered/struct.BufWriter.html pub struct CloudWriter { // Internal writer, constructed at creation writer: BufWriter, diff --git a/crates/polars-io/src/cloud/options.rs b/crates/polars-io/src/cloud/options.rs index de0968a80da0..efaab673f634 100644 --- a/crates/polars-io/src/cloud/options.rs +++ b/crates/polars-io/src/cloud/options.rs @@ -26,13 +26,13 @@ use polars_error::*; #[cfg(feature = "aws")] use polars_utils::cache::FastFixedCache; #[cfg(feature = "aws")] +use polars_utils::pl_str::PlSmallStr; +#[cfg(feature = "aws")] use regex::Regex; #[cfg(feature = "http")] use reqwest::header::HeaderMap; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -#[cfg(feature = "aws")] -use smartstring::alias::String as SmartString; #[cfg(feature = "cloud")] use url::Url; @@ -42,7 +42,7 @@ use crate::file_cache::get_env_file_cache_ttl; use crate::pl_async::with_concurrency_budget; #[cfg(feature = "aws")] -static BUCKET_REGION: Lazy>> = +static BUCKET_REGION: Lazy>> = Lazy::new(|| std::sync::Mutex::new(FastFixedCache::new(32))); /// The type of the config keys must satisfy the following requirements: @@ -277,7 +277,7 @@ impl CloudOptions { &mut builder, &[( Path::new("~/.aws/config"), - &[("region = (.*)\n", AmazonS3ConfigKey::Region)], + &[("region\\s*=\\s*(.*)\n", AmazonS3ConfigKey::Region)], )], ); read_config( @@ -285,9 +285,12 @@ impl CloudOptions { &[( Path::new("~/.aws/credentials"), &[ - ("aws_access_key_id = (.*)\n", AmazonS3ConfigKey::AccessKeyId), ( - "aws_secret_access_key = (.*)\n", + "aws_access_key_id\\s*=\\s*(.*)\n", + AmazonS3ConfigKey::AccessKeyId, + ), + ( + "aws_secret_access_key\\s*=\\s*(.*)\n", AmazonS3ConfigKey::SecretAccessKey, ), ], diff --git a/crates/polars-io/src/cloud/polars_object_store.rs b/crates/polars-io/src/cloud/polars_object_store.rs index f2744432bfa0..9738e0cbdbe4 100644 --- a/crates/polars-io/src/cloud/polars_object_store.rs +++ b/crates/polars-io/src/cloud/polars_object_store.rs @@ -16,6 +16,7 @@ use crate::pl_async::{ /// concurrent requests for the entire application. #[derive(Debug, Clone)] pub struct PolarsObjectStore(Arc); +pub type ObjectStorePath = object_store::path::Path; impl PolarsObjectStore { pub fn new(store: Arc) -> Self { @@ -71,7 +72,9 @@ impl PolarsObjectStore { while let Some(bytes) = stream.next().await { let bytes = bytes.map_err(to_compute_err)?; len += bytes.len(); - file.write(bytes.as_ref()).await.map_err(to_compute_err)?; + file.write_all(bytes.as_ref()) + .await + .map_err(to_compute_err)?; } PolarsResult::Ok(pl_async::Size::from(len as u64)) @@ -82,8 +85,31 @@ impl PolarsObjectStore { /// Fetch the metadata of the parquet file, do not memoize it. pub async fn head(&self, path: &Path) -> PolarsResult { - with_concurrency_budget(1, || self.0.head(path)) - .await - .map_err(to_compute_err) + with_concurrency_budget(1, || async { + let head_result = self.0.head(path).await; + + if head_result.is_err() { + // Pre-signed URLs forbid the HEAD method, but we can still retrieve the header + // information with a range 0-0 request. 
+ let get_range_0_0_result = self + .0 + .get_opts( + path, + object_store::GetOptions { + range: Some((0..1).into()), + ..Default::default() + }, + ) + .await; + + if let Ok(v) = get_range_0_0_result { + return Ok(v.meta); + } + } + + head_result + }) + .await + .map_err(to_compute_err) } } diff --git a/crates/polars-io/src/csv/read/buffer.rs b/crates/polars-io/src/csv/read/buffer.rs index 26e9359a6000..712201ceaca6 100644 --- a/crates/polars-io/src/csv/read/buffer.rs +++ b/crates/polars-io/src/csv/read/buffer.rs @@ -147,7 +147,7 @@ where } pub struct Utf8Field { - name: String, + name: PlSmallStr, mutable: MutableBinaryViewArray, scratch: Vec, quote_char: u8, @@ -155,9 +155,14 @@ pub struct Utf8Field { } impl Utf8Field { - fn new(name: &str, capacity: usize, quote_char: Option, encoding: CsvEncoding) -> Self { + fn new( + name: PlSmallStr, + capacity: usize, + quote_char: Option, + encoding: CsvEncoding, + ) -> Self { Self { - name: name.to_string(), + name, mutable: MutableBinaryViewArray::with_capacity(capacity), scratch: vec![], quote_char: quote_char.unwrap_or(b'"'), @@ -254,7 +259,7 @@ pub struct CategoricalField { #[cfg(feature = "dtype-categorical")] impl CategoricalField { fn new( - name: &str, + name: PlSmallStr, capacity: usize, quote_char: Option, ordering: CategoricalOrdering, @@ -358,7 +363,7 @@ pub struct DatetimeField { #[cfg(any(feature = "dtype-datetime", feature = "dtype-date"))] impl DatetimeField { - fn new(name: &str, capacity: usize) -> Self { + fn new(name: PlSmallStr, capacity: usize) -> Self { let builder = PrimitiveChunkedBuilder::::new(name, capacity); Self { compiled: None, @@ -492,6 +497,7 @@ pub fn init_buffers( .iter() .map(|&i| { let (name, dtype) = schema.get_at_index(i).unwrap(); + let name = name.clone(); let builder = match dtype { &DataType::Boolean => Buffer::Boolean(BooleanChunkedBuilder::new(name, capacity)), #[cfg(feature = "dtype-i8")] @@ -625,7 +631,7 @@ impl Buffer { Buffer::Utf8(v) => { let arr = v.mutable.freeze(); - StringChunked::with_chunk(v.name.as_str(), arr).into_series() + StringChunked::with_chunk(v.name.clone(), arr).into_series() }, #[allow(unused_variables)] Buffer::Categorical(buf) => { diff --git a/crates/polars-io/src/csv/read/mod.rs b/crates/polars-io/src/csv/read/mod.rs index 969be1a58908..b9d48291f8ce 100644 --- a/crates/polars-io/src/csv/read/mod.rs +++ b/crates/polars-io/src/csv/read/mod.rs @@ -26,7 +26,7 @@ mod splitfields; mod utils; pub use options::{CommentPrefix, CsvEncoding, CsvParseOptions, CsvReadOptions, NullValues}; -pub use parser::count_rows; +pub use parser::{count_rows, count_rows_from_slice}; pub use read_impl::batched::{BatchedCsvReader, OwnedBatchedCsvReader}; pub use reader::CsvReader; pub use schema_inference::infer_file_schema; diff --git a/crates/polars-io/src/csv/read/options.rs b/crates/polars-io/src/csv/read/options.rs index 2d10029975e2..83b356fabde8 100644 --- a/crates/polars-io/src/csv/read/options.rs +++ b/crates/polars-io/src/csv/read/options.rs @@ -2,8 +2,9 @@ use std::path::PathBuf; use std::sync::Arc; use polars_core::datatypes::{DataType, Field}; -use polars_core::schema::{IndexOfSchema, Schema, SchemaRef}; +use polars_core::schema::{Schema, SchemaRef}; use polars_error::PolarsResult; +use polars_utils::pl_str::PlSmallStr; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -21,7 +22,7 @@ pub struct CsvReadOptions { pub n_rows: Option, pub row_index: Option, // Column-wise options - pub columns: Option>, + pub columns: Option>, pub projection: Option>>, pub schema: Option, pub 
schema_overwrite: Option, @@ -146,7 +147,7 @@ impl CsvReadOptions { } /// Which columns to select. - pub fn with_columns(mut self, columns: Option>) -> Self { + pub fn with_columns(mut self, columns: Option>) -> Self { self.columns = columns; self } @@ -336,7 +337,7 @@ pub enum CommentPrefix { Single(u8), /// A string that indicates the start of a comment line. /// This allows for multiple characters to be used as a comment identifier. - Multi(Arc), + Multi(PlSmallStr), } impl CommentPrefix { @@ -346,8 +347,8 @@ impl CommentPrefix { } /// Creates a new `CommentPrefix` for the `Multi` variant. - pub fn new_multi(prefix: String) -> Self { - CommentPrefix::Multi(Arc::from(prefix.as_str())) + pub fn new_multi(prefix: PlSmallStr) -> Self { + CommentPrefix::Multi(prefix) } /// Creates a new `CommentPrefix` from a `&str`. @@ -356,7 +357,7 @@ impl CommentPrefix { let c = prefix.as_bytes()[0]; CommentPrefix::Single(c) } else { - CommentPrefix::Multi(Arc::from(prefix)) + CommentPrefix::Multi(PlSmallStr::from_str(prefix)) } } } @@ -371,11 +372,11 @@ impl From<&str> for CommentPrefix { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum NullValues { /// A single value that's used for all columns - AllColumnsSingle(String), + AllColumnsSingle(PlSmallStr), /// Multiple values that are used for all columns - AllColumns(Vec), + AllColumns(Vec), /// Tuples that map column names to null value of that column - Named(Vec<(String, String)>), + Named(Vec<(PlSmallStr, PlSmallStr)>), } impl NullValues { @@ -384,7 +385,7 @@ impl NullValues { NullValues::AllColumnsSingle(v) => NullValuesCompiled::AllColumnsSingle(v), NullValues::AllColumns(v) => NullValuesCompiled::AllColumns(v), NullValues::Named(v) => { - let mut null_values = vec!["".to_string(); schema.len()]; + let mut null_values = vec![PlSmallStr::from_static(""); schema.len()]; for (name, null_value) in v { let i = schema.try_index_of(&name)?; null_values[i] = null_value; @@ -398,11 +399,11 @@ impl NullValues { #[derive(Debug, Clone)] pub(super) enum NullValuesCompiled { /// A single value that's used for all columns - AllColumnsSingle(String), + AllColumnsSingle(PlSmallStr), // Multiple null values that are null for all columns - AllColumns(Vec), + AllColumns(Vec), /// A different null value per column, computed from `NullValues::Named` - Columns(Vec), + Columns(Vec), } impl NullValuesCompiled { diff --git a/crates/polars-io/src/csv/read/parser.rs b/crates/polars-io/src/csv/read/parser.rs index f96f6e04b1b2..282a304003a3 100644 --- a/crates/polars-io/src/csv/read/parser.rs +++ b/crates/polars-io/src/csv/read/parser.rs @@ -1,9 +1,10 @@ -use std::path::PathBuf; +use std::path::Path; use memchr::memchr2_iter; use num_traits::Pow; use polars_core::prelude::*; use polars_core::{config, POOL}; +use polars_error::feature_gated; use polars_utils::index::Bounded; use polars_utils::slice::GetSaferUnchecked; use rayon::prelude::*; @@ -18,7 +19,7 @@ use crate::utils::maybe_decompress_bytes; /// Read the number of rows without parsing columns /// useful for count(*) queries pub fn count_rows( - path: &PathBuf, + path: &Path, separator: u8, quote_char: Option, comment_prefix: Option<&CommentPrefix>, @@ -26,32 +27,47 @@ pub fn count_rows( has_header: bool, ) -> PolarsResult { let file = if is_cloud_url(path) || config::force_async() { - #[cfg(feature = "cloud")] - { + feature_gated!("cloud", { crate::file_cache::FILE_CACHE .get_entry(path.to_str().unwrap()) // Safety: This was initialized by schema inference. .unwrap() .try_open_assume_latest()? 
- } - #[cfg(not(feature = "cloud"))] - { - panic!("required feature `cloud` is not enabled") - } + }) } else { polars_utils::open_file(path)? }; let mmap = unsafe { memmap::Mmap::map(&file).unwrap() }; let owned = &mut vec![]; - let mut reader_bytes = maybe_decompress_bytes(mmap.as_ref(), owned)?; + let reader_bytes = maybe_decompress_bytes(mmap.as_ref(), owned)?; + + count_rows_from_slice( + reader_bytes, + separator, + quote_char, + comment_prefix, + eol_char, + has_header, + ) +} - for _ in 0..reader_bytes.len() { - if reader_bytes[0] != eol_char { +/// Read the number of rows without parsing columns +/// useful for count(*) queries +pub fn count_rows_from_slice( + mut bytes: &[u8], + separator: u8, + quote_char: Option, + comment_prefix: Option<&CommentPrefix>, + eol_char: u8, + has_header: bool, +) -> PolarsResult { + for _ in 0..bytes.len() { + if bytes[0] != eol_char { break; } - reader_bytes = &reader_bytes[1..]; + bytes = &bytes[1..]; } const MIN_ROWS_PER_THREAD: usize = 1024; @@ -59,7 +75,7 @@ pub fn count_rows( // Determine if parallelism is beneficial and how many threads let n_threads = get_line_stats( - reader_bytes, + bytes, MIN_ROWS_PER_THREAD, eol_char, None, @@ -67,22 +83,16 @@ pub fn count_rows( quote_char, ) .map(|(mean, std)| { - let n_rows = (reader_bytes.len() as f32 / (mean - 0.01 * std)) as usize; + let n_rows = (bytes.len() as f32 / (mean - 0.01 * std)) as usize; (n_rows / MIN_ROWS_PER_THREAD).clamp(1, max_threads) }) .unwrap_or(1); - let file_chunks: Vec<(usize, usize)> = get_file_chunks( - reader_bytes, - n_threads, - None, - separator, - quote_char, - eol_char, - ); + let file_chunks: Vec<(usize, usize)> = + get_file_chunks(bytes, n_threads, None, separator, quote_char, eol_char); let iter = file_chunks.into_par_iter().map(|(start, stop)| { - let local_bytes = &reader_bytes[start..stop]; + let local_bytes = &bytes[start..stop]; let row_iterator = SplitLines::new(local_bytes, quote_char.unwrap_or(b'"'), eol_char); if comment_prefix.is_some() { Ok(row_iterator @@ -261,12 +271,6 @@ pub(super) fn skip_whitespace(input: &[u8]) -> &[u8] { skip_condition(input, is_whitespace) } -#[inline] -/// Can be used to skip whitespace, but exclude the separator -pub(super) fn skip_whitespace_exclude(input: &[u8], exclude: u8) -> &[u8] { - skip_condition(input, |b| b != exclude && (is_whitespace(b))) -} - #[inline] pub(super) fn skip_line_ending(input: &[u8], eol_char: u8) -> &[u8] { skip_condition(input, |b| is_line_ending(b, eol_char)) diff --git a/crates/polars-io/src/csv/read/read_impl.rs b/crates/polars-io/src/csv/read/read_impl.rs index 27328114b352..eabd1ca6a2f0 100644 --- a/crates/polars-io/src/csv/read/read_impl.rs +++ b/crates/polars-io/src/csv/read/read_impl.rs @@ -15,7 +15,7 @@ use super::buffer::init_buffers; use super::options::{CommentPrefix, CsvEncoding, NullValues, NullValuesCompiled}; use super::parser::{ get_line_stats, is_comment_line, next_line_position, next_line_position_naive, parse_lines, - skip_bom, skip_line_ending, skip_this_line, skip_whitespace_exclude, + skip_bom, skip_line_ending, skip_this_line, }; use super::schema_inference::{check_decimal_comma, infer_file_schema}; #[cfg(any(feature = "decompress", feature = "decompress-fast"))] @@ -35,7 +35,7 @@ pub(crate) fn cast_columns( ignore_errors: bool, ) -> PolarsResult<()> { let cast_fn = |s: &Series, fld: &Field| { - let out = match (s.dtype(), fld.data_type()) { + let out = match (s.dtype(), fld.dtype()) { #[cfg(feature = "temporal")] (DataType::String, DataType::Date) => s .str() @@ -74,7 +74,7 @@ 
pub(crate) fn cast_columns( df.get_columns() .into_par_iter() .map(|s| { - if let Some(fld) = to_cast.iter().find(|fld| fld.name().as_str() == s.name()) { + if let Some(fld) = to_cast.iter().find(|fld| fld.name() == s.name()) { cast_fn(s, fld) } else { Ok(s.clone()) @@ -150,7 +150,7 @@ impl<'a> CoreReader<'a> { has_header: bool, ignore_errors: bool, schema: Option, - columns: Option>, + columns: Option>, encoding: CsvEncoding, mut n_threads: Option, schema_overwrite: Option, @@ -278,8 +278,9 @@ impl<'a> CoreReader<'a> { ) -> PolarsResult<(&'b [u8], Option)> { let starting_point_offset = bytes.as_ptr() as usize; - // Skip all leading white space and the occasional utf8-bom - bytes = skip_whitespace_exclude(skip_bom(bytes), self.separator); + // Skip utf8 byte-order-mark (BOM) + bytes = skip_bom(bytes); + // \n\n can be a empty string row of a single column // in other cases we skip it. if self.schema.len() > 1 { @@ -495,7 +496,7 @@ impl<'a> CoreReader<'a> { ) }; if let Some(ref row_index) = self.row_index { - df.insert_column(0, Series::new_empty(&row_index.name, &IDX_DTYPE))?; + df.insert_column(0, Series::new_empty(row_index.name.clone(), &IDX_DTYPE))?; } return Ok(df); } @@ -558,7 +559,7 @@ impl<'a> CoreReader<'a> { let mut local_df = unsafe { DataFrame::new_no_checks(columns) }; let current_row_count = local_df.height() as IdxSize; if let Some(rc) = &self.row_index { - local_df.with_row_index_mut(&rc.name, Some(rc.offset)); + local_df.with_row_index_mut(rc.name.clone(), Some(rc.offset)); }; cast_columns(&mut local_df, &self.to_cast, false, self.ignore_errors)?; @@ -616,7 +617,7 @@ impl<'a> CoreReader<'a> { cast_columns(&mut df, &self.to_cast, false, self.ignore_errors)?; if let Some(rc) = &self.row_index { - df.with_row_index_mut(&rc.name, Some(rc.offset)); + df.with_row_index_mut(rc.name.clone(), Some(rc.offset)); } let n_read = df.height() as IdxSize; Ok((df, n_read)) @@ -665,7 +666,7 @@ impl<'a> CoreReader<'a> { cast_columns(&mut df, &self.to_cast, false, self.ignore_errors)?; if let Some(rc) = &self.row_index { - df.with_row_index_mut(&rc.name, Some(rc.offset)); + df.with_row_index_mut(rc.name.clone(), Some(rc.offset)); } let n_read = df.height() as IdxSize; (df, n_read) diff --git a/crates/polars-io/src/csv/read/read_impl/batched.rs b/crates/polars-io/src/csv/read/read_impl/batched.rs index c4be765648cb..3bf6e2dd4e32 100644 --- a/crates/polars-io/src/csv/read/read_impl/batched.rs +++ b/crates/polars-io/src/csv/read/read_impl/batched.rs @@ -258,7 +258,7 @@ impl<'a> BatchedCsvReader<'a> { cast_columns(&mut df, &self.to_cast, false, self.ignore_errors)?; if let Some(rc) = &self.row_index { - df.with_row_index_mut(&rc.name, Some(rc.offset)); + df.with_row_index_mut(rc.name.clone(), Some(rc.offset)); } Ok(df) }) diff --git a/crates/polars-io/src/csv/read/reader.rs b/crates/polars-io/src/csv/read/reader.rs index c45e18f3c098..49fb576fff8a 100644 --- a/crates/polars-io/src/csv/read/reader.rs +++ b/crates/polars-io/src/csv/read/reader.rs @@ -168,7 +168,7 @@ impl CsvReader { .map(|mut fld| { use DataType::*; - match fld.data_type() { + match fld.dtype() { Time => { self.options.fields_to_cast.push(fld.clone()); fld.coerce(String); @@ -304,7 +304,7 @@ where let schema = dtypes .iter() .zip(df.get_column_names()) - .map(|(dtype, name)| Field::new(name, dtype.clone())) + .map(|(dtype, name)| Field::new(name.clone(), dtype.clone())) .collect::(); Arc::new(schema) diff --git a/crates/polars-io/src/csv/read/schema_inference.rs b/crates/polars-io/src/csv/read/schema_inference.rs index 
189c54501c12..50942d8b5d56 100644 --- a/crates/polars-io/src/csv/read/schema_inference.rs +++ b/crates/polars-io/src/csv/read/schema_inference.rs @@ -6,6 +6,7 @@ use polars_core::prelude::*; use polars_time::chunkedarray::string::infer as date_infer; #[cfg(feature = "polars-time")] use polars_time::prelude::string::Pattern; +use polars_utils::format_pl_smallstr; use polars_utils::slice::GetSaferUnchecked; use super::options::{CommentPrefix, CsvEncoding, NullValues}; @@ -129,9 +130,10 @@ pub fn infer_field_schema(string: &str, try_parse_dates: bool, decimal_comma: bo DataType::Datetime(TimeUnit::Microseconds, None) }, Pattern::DateYMD | Pattern::DateDMY => DataType::Date, - Pattern::DatetimeYMDZ => { - DataType::Datetime(TimeUnit::Microseconds, Some("UTC".to_string())) - }, + Pattern::DatetimeYMDZ => DataType::Datetime( + TimeUnit::Microseconds, + Some(PlSmallStr::from_static("UTC")), + ), }, None => DataType::String, } @@ -162,9 +164,10 @@ pub fn infer_field_schema(string: &str, try_parse_dates: bool, decimal_comma: bo DataType::Datetime(TimeUnit::Microseconds, None) }, Pattern::DateYMD | Pattern::DateDMY => DataType::Date, - Pattern::DatetimeYMDZ => { - DataType::Datetime(TimeUnit::Microseconds, Some("UTC".to_string())) - }, + Pattern::DatetimeYMDZ => DataType::Datetime( + TimeUnit::Microseconds, + Some(PlSmallStr::from_static("UTC")), + ), }, None => DataType::String, } @@ -241,7 +244,7 @@ fn infer_file_schema_inner( } // now that we've found the first non-comment line we parse the headers, or we create a header - let headers: Vec = if let Some(mut header_line) = first_line { + let headers: Vec = if let Some(mut header_line) = first_line { let len = header_line.len(); if len > 1 { // remove carriage return @@ -272,9 +275,9 @@ fn infer_file_schema_inner( for name in &headers { let count = header_names.entry(name.as_ref()).or_insert(0usize); if *count != 0 { - final_headers.push(format!("{}_duplicated_{}", name, *count - 1)) + final_headers.push(format_pl_smallstr!("{}_duplicated_{}", name, *count - 1)) } else { - final_headers.push(name.to_string()) + final_headers.push(PlSmallStr::from_str(name)) } *count += 1; } @@ -282,8 +285,8 @@ fn infer_file_schema_inner( } else { byterecord .enumerate() - .map(|(i, _s)| format!("column_{}", i + 1)) - .collect::>() + .map(|(i, _s)| format_pl_smallstr!("column_{}", i + 1)) + .collect::>() } } else if has_header && !bytes.is_empty() && recursion_count == 0 { // there was no new line char. So we copy the whole buf and add one @@ -311,7 +314,7 @@ fn infer_file_schema_inner( decimal_comma, ); } else if !raise_if_empty { - return Ok((Schema::new(), 0, 0)); + return Ok((Schema::default(), 0, 0)); } else { polars_bail!(NoData: "empty CSV"); }; @@ -395,7 +398,7 @@ fn infer_file_schema_inner( } }, Some(NullValues::AllColumnsSingle(name)) => { - if s.as_ref() != name { + if s.as_ref() != name.as_str() { Some(infer_field_schema(&s, try_parse_dates, decimal_comma)) } else { None @@ -405,10 +408,10 @@ fn infer_file_schema_inner( // SAFETY: // we iterate over headers length. 
let current_name = unsafe { headers.get_unchecked_release(i) }; - let null_name = &names.iter().find(|name| &name.0 == current_name); + let null_name = &names.iter().find(|name| name.0 == current_name); if let Some(null_name) = null_name { - if null_name.1 != s.as_ref() { + if null_name.1.as_str() != s.as_ref() { Some(infer_field_schema(&s, try_parse_dates, decimal_comma)) } else { None @@ -448,7 +451,7 @@ fn infer_file_schema_inner( if let Some(schema_overwrite) = schema_overwrite { if let Some((_, name, dtype)) = schema_overwrite.get_full(field_name) { - fields.push(Field::new(name, dtype.clone())); + fields.push(Field::new(name.clone(), dtype.clone())); continue; } @@ -456,7 +459,7 @@ fn infer_file_schema_inner( // execute only if schema is complete if schema_overwrite.len() == header_length { if let Some((name, dtype)) = schema_overwrite.get_at_index(i) { - fields.push(Field::new(name, dtype.clone())); + fields.push(Field::new(name.clone(), dtype.clone())); continue; } } @@ -464,7 +467,7 @@ fn infer_file_schema_inner( let possibilities = &column_types[i]; let dtype = finish_infer_field_schema(possibilities); - fields.push(Field::new(field_name, dtype)); + fields.push(Field::new(field_name.clone(), dtype)); } // if there is a single line after the header without an eol // we copy the bytes add an eol and rerun this function @@ -502,7 +505,7 @@ fn infer_file_schema_inner( pub(super) fn check_decimal_comma(decimal_comma: bool, separator: u8) -> PolarsResult<()> { if decimal_comma { - polars_ensure!(b',' != separator, InvalidOperation: "'decimal_comma' argument cannot be combined with ',' quote char") + polars_ensure!(b',' != separator, InvalidOperation: "'decimal_comma' argument cannot be combined with ',' separator") } Ok(()) } diff --git a/crates/polars-io/src/csv/read/splitfields.rs b/crates/polars-io/src/csv/read/splitfields.rs index d3aed398c742..59f9bcd53bd8 100644 --- a/crates/polars-io/src/csv/read/splitfields.rs +++ b/crates/polars-io/src/csv/read/splitfields.rs @@ -145,16 +145,6 @@ mod inner { const SIMD_SIZE: usize = 16; type SimdVec = u8x16; - #[inline] - unsafe fn simple_argmax(arr: &[bool; SIMD_SIZE]) -> usize { - for (i, item) in arr.iter().enumerate() { - if *item { - return i; - } - } - unreachable!(); - } - /// An adapted version of std::iter::Split. /// This exists solely because we cannot split the lines naively as pub(crate) struct SplitFields<'a> { @@ -276,26 +266,21 @@ mod inner { let bytes = unsafe { self.v.get_unchecked_release(total_idx..) 
}; if bytes.len() > SIMD_SIZE { - unsafe { - let lane: [u8; SIMD_SIZE] = bytes + let lane: [u8; SIMD_SIZE] = unsafe { + bytes .get_unchecked(0..SIMD_SIZE) .try_into() - .unwrap_unchecked_release(); - let simd_bytes = SimdVec::from(lane); - let has_eol_char = simd_bytes.simd_eq(self.simd_eol_char); - let has_separator = simd_bytes.simd_eq(self.simd_separator); - let has_any = has_separator.bitor(has_eol_char); - if has_any.any() { - // soundness we can transmute because we have the same alignment - let has_any = std::mem::transmute::< - Mask<_, SIMD_SIZE>, - [bool; SIMD_SIZE], - >(has_any); - total_idx += simple_argmax(&has_any); - break; - } else { - total_idx += SIMD_SIZE; - } + .unwrap_unchecked_release() + }; + let simd_bytes = SimdVec::from(lane); + let has_eol_char = simd_bytes.simd_eq(self.simd_eol_char); + let has_separator = simd_bytes.simd_eq(self.simd_separator); + let has_any = has_separator.bitor(has_eol_char); + if let Some(idx) = has_any.first_set() { + total_idx += idx; + break; + } else { + total_idx += SIMD_SIZE; } } else { match bytes.iter().position(|&c| self.eof_oel(c)) { @@ -317,7 +302,7 @@ mod inner { }; unsafe { - debug_assert!(pos <= self.v.len()); + debug_assert!(pos < self.v.len()); // SAFETY: // we are in bounds let ret = Some((self.v.get_unchecked(..pos), needs_escaping)); diff --git a/crates/polars-io/src/csv/write/writer.rs b/crates/polars-io/src/csv/write/writer.rs index 9369dacbe6da..f3017ce189ec 100644 --- a/crates/polars-io/src/csv/write/writer.rs +++ b/crates/polars-io/src/csv/write/writer.rs @@ -2,7 +2,7 @@ use std::io::Write; use std::num::NonZeroUsize; use polars_core::frame::DataFrame; -use polars_core::schema::{IndexOfSchema, Schema}; +use polars_core::schema::Schema; use polars_core::POOL; use polars_error::PolarsResult; @@ -49,9 +49,13 @@ where if self.bom { write_bom(&mut self.buffer)?; } - let names = df.get_column_names(); + let names = df + .get_column_names() + .into_iter() + .map(|x| x.as_str()) + .collect::>(); if self.header { - write_header(&mut self.buffer, &names, &self.options)?; + write_header(&mut self.buffer, names.as_slice(), &self.options)?; } write( &mut self.buffer, @@ -193,8 +197,16 @@ impl BatchedWriter { if !self.has_written_header { self.has_written_header = true; - let names = df.get_column_names(); - write_header(&mut self.writer.buffer, &names, &self.writer.options)?; + let names = df + .get_column_names() + .into_iter() + .map(|x| x.as_str()) + .collect::>(); + write_header( + &mut self.writer.buffer, + names.as_slice(), + &self.writer.options, + )?; } write( @@ -216,7 +228,11 @@ impl BatchedWriter { if !self.has_written_header { self.has_written_header = true; - let names = self.schema.get_names(); + let names = self + .schema + .iter_names() + .map(|x| x.as_str()) + .collect::>(); write_header(&mut self.writer.buffer, &names, &self.writer.options)?; }; diff --git a/crates/polars-io/src/hive.rs b/crates/polars-io/src/hive.rs index ddf1d8973b3e..b027e6d1d054 100644 --- a/crates/polars-io/src/hive.rs +++ b/crates/polars-io/src/hive.rs @@ -1,5 +1,4 @@ use polars_core::frame::DataFrame; -use polars_core::schema::IndexOfSchema; use polars_core::series::Series; /// Materializes hive partitions. @@ -9,9 +8,9 @@ use polars_core::series::Series; /// # Safety /// /// num_rows equals the height of the df when the df height is non-zero. 
-pub(crate) fn materialize_hive_partitions( +pub(crate) fn materialize_hive_partitions( df: &mut DataFrame, - reader_schema: &S, + reader_schema: &polars_schema::Schema, hive_partition_columns: Option<&[Series]>, num_rows: usize, ) { diff --git a/crates/polars-io/src/ipc/ipc_file.rs b/crates/polars-io/src/ipc/ipc_file.rs index e3c557eac1f3..feaea44f5417 100644 --- a/crates/polars-io/src/ipc/ipc_file.rs +++ b/crates/polars-io/src/ipc/ipc_file.rs @@ -12,8 +12,8 @@ //! use std::io::Cursor; //! //! -//! let s0 = Series::new("days", &[0, 1, 2, 3, 4]); -//! let s1 = Series::new("temp", &[22.1, 19.9, 7., 2., 3.]); +//! let s0 = Series::new("days".into(), &[0, 1, 2, 3, 4]); +//! let s1 = Series::new("temp".into(), &[22.1, 19.9, 7., 2., 3.]); //! let mut df = DataFrame::new(vec![s0, s1]).unwrap(); //! //! // Create an in memory file handler. @@ -51,9 +51,7 @@ use crate::RowIndex; #[derive(Clone, Debug, PartialEq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct IpcScanOptions { - pub memory_map: bool, -} +pub struct IpcScanOptions; /// Read Arrows IPC format into a DataFrame /// @@ -81,7 +79,7 @@ pub struct IpcReader { pub(super) projection: Option>, pub(crate) columns: Option>, hive_partition_columns: Option>, - include_file_path: Option<(Arc, Arc)>, + include_file_path: Option<(PlSmallStr, Arc)>, pub(super) row_index: Option, // Stores the as key semaphore to make sure we don't write to the memory mapped file. pub(super) memory_map: Option, @@ -136,7 +134,7 @@ impl IpcReader { pub fn with_include_file_path( mut self, - include_file_path: Option<(Arc, Arc)>, + include_file_path: Option<(PlSmallStr, Arc)>, ) -> Self { self.include_file_path = include_file_path; self @@ -300,7 +298,7 @@ impl SerReader for IpcReader { if let Some((col, value)) = include_file_path { unsafe { - df.with_column_unchecked(StringChunked::full(&col, &value, row_count).into_series()) + df.with_column_unchecked(StringChunked::full(col, &value, row_count).into_series()) }; } diff --git a/crates/polars-io/src/ipc/ipc_reader_async.rs b/crates/polars-io/src/ipc/ipc_reader_async.rs index 9d392575e956..089dceaf9a8a 100644 --- a/crates/polars-io/src/ipc/ipc_reader_async.rs +++ b/crates/polars-io/src/ipc/ipc_reader_async.rs @@ -5,8 +5,9 @@ use object_store::path::Path; use object_store::ObjectMeta; use polars_core::datatypes::IDX_DTYPE; use polars_core::frame::DataFrame; -use polars_core::schema::Schema; +use polars_core::schema::{Schema, SchemaExt}; use polars_error::{polars_bail, polars_err, to_compute_err, PolarsResult}; +use polars_utils::pl_str::PlSmallStr; use crate::cloud::{ build_object_store, object_path_from_str, CloudLocation, CloudOptions, PolarsObjectStore, @@ -27,7 +28,7 @@ pub struct IpcReaderAsync { #[derive(Default, Clone)] pub struct IpcReadOptions { // Names of the columns to include in the output. - projection: Option>, + projection: Option>, // The maximum number of rows to include in the output. 
row_limit: Option, @@ -40,7 +41,7 @@ pub struct IpcReadOptions { } impl IpcReadOptions { - pub fn with_projection(mut self, projection: Option>) -> Self { + pub fn with_projection(mut self, projection: Option>) -> Self { self.projection = projection; self } @@ -141,7 +142,7 @@ impl IpcReaderAsync { Some(projection) => { fn prepare_schema(mut schema: Schema, row_index: Option<&RowIndex>) -> Schema { if let Some(rc) = row_index { - let _ = schema.insert_at_index(0, rc.name.as_ref().into(), IDX_DTYPE); + let _ = schema.insert_at_index(0, rc.name.clone(), IDX_DTYPE); } schema } @@ -156,7 +157,10 @@ impl IpcReaderAsync { &fetched_metadata }; - let schema = prepare_schema((&metadata.schema).into(), options.row_index.as_ref()); + let schema = prepare_schema( + Schema::from_arrow_schema(metadata.schema.as_ref()), + options.row_index.as_ref(), + ); let hive_partitions = None; diff --git a/crates/polars-io/src/ipc/ipc_stream.rs b/crates/polars-io/src/ipc/ipc_stream.rs index c8429e1b2d80..545f19168f9f 100644 --- a/crates/polars-io/src/ipc/ipc_stream.rs +++ b/crates/polars-io/src/ipc/ipc_stream.rs @@ -13,8 +13,8 @@ //! use std::io::Cursor; //! //! -//! let s0 = Series::new("days", &[0, 1, 2, 3, 4]); -//! let s1 = Series::new("temp", &[22.1, 19.9, 7., 2., 3.]); +//! let s0 = Series::new("days".into(), &[0, 1, 2, 3, 4]); +//! let s1 = Series::new("temp".into(), &[22.1, 19.9, 7., 2., 3.]); //! let mut df = DataFrame::new(vec![s0, s1]).unwrap(); //! //! // Create an in memory file handler. @@ -76,7 +76,7 @@ pub struct IpcStreamReader { impl IpcStreamReader { /// Get schema of the Ipc Stream File pub fn schema(&mut self) -> PolarsResult { - Ok(Schema::from_iter(&self.metadata()?.schema.fields)) + Ok(Schema::from_arrow_schema(&self.metadata()?.schema)) } /// Get arrow schema of the Ipc Stream File, this is faster than creating a polars schema. 
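The async IPC path above now resolves projections by name (PlSmallStr) and builds its Polars schema with Schema::from_arrow_schema, splicing a requested row-index column in at position 0 before the projection is applied. A minimal sketch of that fix-up, reusing the Schema::insert_at_index / IDX_DTYPE calls shown in the hunk (the helper name here is illustrative):

use polars_core::datatypes::IDX_DTYPE;
use polars_core::schema::Schema;
use polars_utils::pl_str::PlSmallStr;

// Sketch only: mirrors `prepare_schema` above. A requested row-index column
// becomes the first column of the schema (typed as IDX_DTYPE) before the
// projection is resolved against it.
fn with_row_index_column(mut schema: Schema, row_index_name: Option<PlSmallStr>) -> Schema {
    if let Some(name) = row_index_name {
        let _ = schema.insert_at_index(0, name, IDX_DTYPE);
    }
    schema
}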
diff --git a/crates/polars-io/src/ipc/mmap.rs b/crates/polars-io/src/ipc/mmap.rs index 854bd4c8d9d7..f0343642482e 100644 --- a/crates/polars-io/src/ipc/mmap.rs +++ b/crates/polars-io/src/ipc/mmap.rs @@ -3,9 +3,10 @@ use arrow::io::ipc::read::{Dictionaries, FileMetadata}; use arrow::mmap::{mmap_dictionaries_unchecked, mmap_unchecked}; use arrow::record_batch::RecordBatch; use polars_core::prelude::*; +use polars_utils::mmap::MMapSemaphore; use super::ipc_file::IpcReader; -use crate::mmap::{MMapSemaphore, MmapBytesReader}; +use crate::mmap::MmapBytesReader; use crate::predicates::PhysicalIoExpr; use crate::shared::{finish_reader, ArrowReader}; use crate::utils::{apply_projection, columns_to_projection}; @@ -15,17 +16,9 @@ impl IpcReader { &mut self, predicate: Option>, ) -> PolarsResult { - #[cfg(target_family = "unix")] - use std::os::unix::fs::MetadataExt; match self.reader.to_file() { Some(file) => { - #[cfg(target_family = "unix")] - let metadata = file.metadata()?; - let mmap = unsafe { memmap::Mmap::map(file).unwrap() }; - #[cfg(target_family = "unix")] - let semaphore = MMapSemaphore::new(metadata.dev(), metadata.ino(), mmap); - #[cfg(not(target_family = "unix"))] - let semaphore = MMapSemaphore::new(mmap); + let semaphore = MMapSemaphore::new_from_file(file)?; let metadata = read::read_file_metadata(&mut std::io::Cursor::new(semaphore.as_ref()))?; diff --git a/crates/polars-io/src/json/infer.rs b/crates/polars-io/src/json/infer.rs index 9cd82721d156..0ff83225e97f 100644 --- a/crates/polars-io/src/json/infer.rs +++ b/crates/polars-io/src/json/infer.rs @@ -22,7 +22,7 @@ pub(crate) fn json_values_to_supertype( .unwrap_or_else(|| polars_bail!(ComputeError: "could not infer data-type")) } -pub(crate) fn data_types_to_supertype>( +pub(crate) fn dtypes_to_supertype>( datatypes: I, ) -> PolarsResult { datatypes diff --git a/crates/polars-io/src/json/mod.rs b/crates/polars-io/src/json/mod.rs index 99dbd53ffa5d..1a8f9eb8f5a4 100644 --- a/crates/polars-io/src/json/mod.rs +++ b/crates/polars-io/src/json/mod.rs @@ -71,6 +71,7 @@ use std::ops::Deref; use arrow::legacy::conversion::chunk_to_struct; use polars_core::error::to_compute_err; use polars_core::prelude::*; +use polars_error::{polars_bail, PolarsResult}; use polars_json::json::write::FallibleStreamingIterator; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -86,9 +87,11 @@ pub struct JsonWriterOptions { pub maintain_order: bool, } -/// The format to use to write the DataFrame to JSON: `Json` (a JSON array) or `JsonLines` (each row output on a -/// separate line). In either case, each row is serialized as a JSON object whose keys are the column names and whose -/// values are the row's corresponding values. +/// The format to use to write the DataFrame to JSON: `Json` (a JSON array) +/// or `JsonLines` (each row output on a separate line). +/// +/// In either case, each row is serialized as a JSON object whose keys are the column names and +/// whose values are the row's corresponding values. pub enum JsonFormat { /// A single JSON array containing each DataFrame row as an object. The length of the array is the number of rows in /// the DataFrame. 
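The rename from data_types_to_supertype to dtypes_to_supertype in json/infer.rs does not change behaviour: the inferred dtypes are still folded pairwise into one common supertype. A rough self-contained sketch of that fold, assuming try_get_supertype from polars_core::utils (the function name below is illustrative):

use polars_core::prelude::*;
use polars_core::utils::try_get_supertype;
use polars_error::{polars_bail, PolarsResult};

// Sketch of the supertype fold performed by `dtypes_to_supertype`: pairwise
// supertypes are computed left-to-right, and an empty input is an error.
fn supertype_of(dtypes: impl IntoIterator<Item = DataType>) -> PolarsResult<DataType> {
    dtypes
        .into_iter()
        .map(Ok)
        .reduce(|a, b| try_get_supertype(&a?, &b?))
        .unwrap_or_else(|| polars_bail!(ComputeError: "could not infer data-type"))
}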
@@ -216,12 +219,23 @@ where ignore_errors: bool, infer_schema_len: Option, batch_size: NonZeroUsize, - projection: Option>, + projection: Option>, schema: Option, schema_overwrite: Option<&'a Schema>, json_format: JsonFormat, } +pub fn remove_bom(bytes: &[u8]) -> PolarsResult<&[u8]> { + if bytes.starts_with(&[0xEF, 0xBB, 0xBF]) { + // UTF-8 BOM + Ok(&bytes[3..]) + } else if bytes.starts_with(&[0xFE, 0xFF]) || bytes.starts_with(&[0xFF, 0xFE]) { + // UTF-16 BOM + polars_bail!(ComputeError: "utf-16 not supported") + } else { + Ok(bytes) + } +} impl<'a, R> SerReader for JsonReader<'a, R> where R: MmapBytesReader, @@ -251,8 +265,9 @@ where /// incompatible types in the input. In the event that a column contains mixed dtypes, is it unspecified whether an /// error is returned or whether elements of incompatible dtypes are replaced with `null`. fn finish(mut self) -> PolarsResult { - let rb: ReaderBytes = (&mut self.reader).into(); - + let pre_rb: ReaderBytes = (&mut self.reader).into(); + let bytes = remove_bom(pre_rb.deref())?; + let rb = ReaderBytes::Borrowed(bytes); let out = match self.json_format { JsonFormat::Json => { polars_ensure!(!self.ignore_errors, InvalidOperation: "'ignore_errors' only supported in ndjson"); @@ -286,13 +301,13 @@ where polars_bail!(ComputeError: "can only deserialize json objects") }; - let mut schema = Schema::from_iter(fields.iter()); + let mut schema = Schema::from_iter(fields.iter().map(Into::::into)); overwrite_schema(&mut schema, overwrite)?; DataType::Struct( schema .into_iter() - .map(|(name, dt)| Field::new(&name, dt)) + .map(|(name, dt)| Field::new(name, dt)) .collect(), ) .to_arrow(CompatLevel::newest()) @@ -303,7 +318,9 @@ where let dtype = if let BorrowedValue::Array(_) = &json_value { ArrowDataType::LargeList(Box::new(arrow::datatypes::Field::new( - "item", dtype, true, + PlSmallStr::from_static("item"), + dtype, + true, ))) } else { dtype @@ -340,8 +357,8 @@ where }?; // TODO! Ensure we don't materialize the columns we don't need - if let Some(proj) = &self.projection { - out.select(proj) + if let Some(proj) = self.projection.as_deref() { + out.select(proj.iter().cloned()) } else { Ok(out) } @@ -390,7 +407,7 @@ where /// /// Setting `projection` to the columns you want to keep is more efficient than deserializing all of the columns and /// then dropping the ones you don't want. 
- pub fn with_projection(mut self, projection: Option>) -> Self { + pub fn with_projection(mut self, projection: Option>) -> Self { self.projection = projection; self } diff --git a/crates/polars-io/src/lib.rs b/crates/polars-io/src/lib.rs index 5aa6e7fcebab..f3540f4e13fd 100644 --- a/crates/polars-io/src/lib.rs +++ b/crates/polars-io/src/lib.rs @@ -1,6 +1,7 @@ #![cfg_attr(docsrs, feature(doc_auto_cfg))] #![cfg_attr(feature = "simd", feature(portable_simd))] #![allow(ambiguous_glob_reexports)] +extern crate core; #[cfg(feature = "avro")] pub mod avro; diff --git a/crates/polars-io/src/mmap.rs b/crates/polars-io/src/mmap.rs index 73dd48498f51..498c73da1a9d 100644 --- a/crates/polars-io/src/mmap.rs +++ b/crates/polars-io/src/mmap.rs @@ -1,77 +1,9 @@ -use std::collections::btree_map::Entry; -use std::collections::BTreeMap; use std::fs::File; use std::io::{BufReader, Cursor, Read, Seek}; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; -use memmap::Mmap; -use once_cell::sync::Lazy; use polars_core::config::verbose; -use polars_error::{polars_bail, PolarsResult}; -use polars_utils::mmap::MemSlice; - -// Keep track of memory mapped files so we don't write to them while reading -// Use a btree as it uses less memory than a hashmap and this thing never shrinks. -// Write handle in Windows is exclusive, so this is only necessary in Unix. -#[cfg(target_family = "unix")] -static MEMORY_MAPPED_FILES: Lazy>> = - Lazy::new(|| Mutex::new(Default::default())); - -pub(crate) struct MMapSemaphore { - #[cfg(target_family = "unix")] - key: (u64, u64), - mmap: Mmap, -} - -impl MMapSemaphore { - #[cfg(target_family = "unix")] - pub(super) fn new(dev: u64, ino: u64, mmap: Mmap) -> Self { - let mut guard = MEMORY_MAPPED_FILES.lock().unwrap(); - let key = (dev, ino); - guard.insert(key, 1); - Self { key, mmap } - } - - #[cfg(not(target_family = "unix"))] - pub(super) fn new(mmap: Mmap) -> Self { - Self { mmap } - } -} - -impl AsRef<[u8]> for MMapSemaphore { - #[inline] - fn as_ref(&self) -> &[u8] { - self.mmap.as_ref() - } -} - -#[cfg(target_family = "unix")] -impl Drop for MMapSemaphore { - fn drop(&mut self) { - let mut guard = MEMORY_MAPPED_FILES.lock().unwrap(); - if let Entry::Occupied(mut e) = guard.entry(self.key) { - let v = e.get_mut(); - *v -= 1; - - if *v == 0 { - e.remove_entry(); - } - } - } -} - -pub fn ensure_not_mapped(file: &File) -> PolarsResult<()> { - #[cfg(target_family = "unix")] - { - use std::os::unix::fs::MetadataExt; - let guard = MEMORY_MAPPED_FILES.lock().unwrap(); - let metadata = file.metadata()?; - if guard.contains_key(&(metadata.dev(), metadata.ino())) { - polars_bail!(ComputeError: "cannot write to file: already memory mapped"); - } - } - Ok(()) -} +use polars_utils::mmap::{MMapSemaphore, MemSlice}; /// Trait used to get a hold to file handler or to the underlying bytes /// without performing a Read. 
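The (device, inode) registry deleted from this file is relocated rather than removed: polars_utils::mmap::MMapSemaphore now owns it, so writers can still refuse to touch a file that is currently memory mapped. A standalone sketch of the underlying idea, with illustrative names rather than the actual polars_utils API:

use std::collections::BTreeMap;
use std::sync::Mutex;

// Illustrative only: a global registry keyed by (device, inode), bumped when a
// mapping is created and decremented when it is dropped. A writer checks the
// registry before opening the file for writing.
static MAPPED_FILES: Mutex<BTreeMap<(u64, u64), u32>> = Mutex::new(BTreeMap::new());

fn register_mapping(key: (u64, u64)) {
    *MAPPED_FILES.lock().unwrap().entry(key).or_insert(0) += 1;
}

fn release_mapping(key: (u64, u64)) {
    let mut guard = MAPPED_FILES.lock().unwrap();
    if let Some(count) = guard.get_mut(&key) {
        *count -= 1;
        if *count == 0 {
            guard.remove(&key);
        }
    }
}

fn is_mapped(key: (u64, u64)) -> bool {
    MAPPED_FILES.lock().unwrap().contains_key(&key)
}

fn main() {
    let key = (42, 7); // (device, inode) of a mapped file
    register_mapping(key);
    assert!(is_mapped(key));
    release_mapping(key);
    assert!(!is_mapped(key));
}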
@@ -97,6 +29,12 @@ impl MmapBytesReader for BufReader { } } +impl MmapBytesReader for BufReader<&File> { + fn to_file(&self) -> Option<&File> { + Some(self.get_ref()) + } +} + impl MmapBytesReader for Cursor where T: AsRef<[u8]> + Send + Sync, @@ -130,7 +68,7 @@ impl MmapBytesReader for &mut T { pub enum ReaderBytes<'a> { Borrowed(&'a [u8]), Owned(Vec), - Mapped(memmap::Mmap, &'a File), + Mapped(MMapSemaphore, &'a File), } impl std::ops::Deref for ReaderBytes<'_> { @@ -139,7 +77,7 @@ impl std::ops::Deref for ReaderBytes<'_> { match self { Self::Borrowed(ref_bytes) => ref_bytes, Self::Owned(vec) => vec, - Self::Mapped(mmap, _) => mmap, + Self::Mapped(mmap, _) => mmap.as_ref(), } } } @@ -167,7 +105,7 @@ impl<'a, T: 'a + MmapBytesReader> From<&'a mut T> for ReaderBytes<'a> { None => { if let Some(f) = m.to_file() { let f = unsafe { std::mem::transmute::<&File, &'a File>(f) }; - let mmap = unsafe { memmap::Mmap::map(f).unwrap() }; + let mmap = MMapSemaphore::new_from_file(f).unwrap(); ReaderBytes::Mapped(mmap, f) } else { if verbose() { diff --git a/crates/polars-io/src/ndjson/buffer.rs b/crates/polars-io/src/ndjson/buffer.rs index df526dc49ec4..2bb2a028f1ca 100644 --- a/crates/polars-io/src/ndjson/buffer.rs +++ b/crates/polars-io/src/ndjson/buffer.rs @@ -29,7 +29,7 @@ pub(crate) struct Buffer<'a> { impl Buffer<'_> { pub fn into_series(self) -> Series { let mut s = self.buf.into_series(); - s.rename(self.name); + s.rename(PlSmallStr::from_str(self.name)); s } @@ -201,7 +201,8 @@ fn deserialize_all<'a>( .iter() .map(|val| deserialize_all(val, inner_dtype, ignore_errors)) .collect::>()?; - let s = Series::from_any_values_and_dtype("", &vals, inner_dtype, false)?; + let s = + Series::from_any_values_and_dtype(PlSmallStr::EMPTY, &vals, inner_dtype, false)?; AnyValue::List(s) }, #[cfg(feature = "dtype-struct")] diff --git a/crates/polars-io/src/ndjson/core.rs b/crates/polars-io/src/ndjson/core.rs index 2beb1f09d88b..c3754f9403d1 100644 --- a/crates/polars-io/src/ndjson/core.rs +++ b/crates/polars-io/src/ndjson/core.rs @@ -14,9 +14,8 @@ use crate::mmap::{MmapBytesReader, ReaderBytes}; use crate::ndjson::buffer::*; use crate::predicates::PhysicalIoExpr; use crate::prelude::*; -use crate::RowIndex; +use crate::{RowIndex, SerReader}; const NEWLINE: u8 = b'\n'; -const RETURN: u8 = b'\r'; const CLOSING_BRACKET: u8 = b'}'; #[must_use] @@ -37,7 +36,7 @@ where ignore_errors: bool, row_index: Option<&'a mut RowIndex>, predicate: Option>, - projection: Option>, + projection: Option>, } impl<'a, R> JsonLineReader<'a, R> @@ -68,7 +67,7 @@ where self } - pub fn with_projection(mut self, projection: Option>) -> Self { + pub fn with_projection(mut self, projection: Option>) -> Self { self.projection = projection; self } @@ -203,7 +202,7 @@ pub(crate) struct CoreJsonReader<'a> { ignore_errors: bool, row_index: Option<&'a mut RowIndex>, predicate: Option>, - projection: Option>, + projection: Option>, } impl<'a> CoreJsonReader<'a> { #[allow(clippy::too_many_arguments)] @@ -220,7 +219,7 @@ impl<'a> CoreJsonReader<'a> { ignore_errors: bool, row_index: Option<&'a mut RowIndex>, predicate: Option>, - projection: Option>, + projection: Option>, ) -> PolarsResult> { let reader_bytes = reader_bytes; @@ -259,8 +258,7 @@ impl<'a> CoreJsonReader<'a> { let iter = file_chunks.par_iter().map(|(start_pos, stop_at_nbytes)| { let bytes = &bytes[*start_pos..*stop_at_nbytes]; - let iter = serde_json::Deserializer::from_slice(bytes) - .into_iter::>(); + let iter = json_lines(bytes); iter.count() }); Ok(POOL.install(|| iter.sum())) @@ 
-316,13 +314,13 @@ impl<'a> CoreJsonReader<'a> { )?; let prepredicate_height = local_df.height() as IdxSize; - if let Some(projection) = &self.projection { - local_df = local_df.select(projection.as_ref())?; + if let Some(projection) = self.projection.as_deref() { + local_df = local_df.select(projection.iter().cloned())?; } if let Some(row_index) = row_index { local_df = local_df - .with_row_index(row_index.name.as_ref(), Some(row_index.offset))?; + .with_row_index(row_index.name.clone(), Some(row_index.offset))?; } if let Some(predicate) = &self.predicate { @@ -366,56 +364,55 @@ impl<'a> CoreJsonReader<'a> { fn parse_impl( bytes: &[u8], buffers: &mut PlIndexMap, - scratch: &mut Vec, + scratch: &mut Scratch, ) -> PolarsResult { - scratch.clear(); - scratch.extend_from_slice(bytes); - let n = scratch.len(); - let all_good = match n { - 0 => true, - 1 => scratch[0] == NEWLINE, - 2 => scratch[0] == NEWLINE && scratch[1] == RETURN, + scratch.json.clear(); + scratch.json.extend_from_slice(bytes); + let n = scratch.json.len(); + let value = simd_json::to_borrowed_value_with_buffers(&mut scratch.json, &mut scratch.buffers) + .map_err(|e| polars_err!(ComputeError: "error parsing line: {}", e))?; + match value { + simd_json::BorrowedValue::Object(value) => { + buffers.iter_mut().try_for_each(|(s, inner)| { + match s.0.map_lookup(&value) { + Some(v) => inner.add(v)?, + None => inner.add_null(), + } + PolarsResult::Ok(()) + })?; + }, _ => { - let value: simd_json::BorrowedValue = simd_json::to_borrowed_value(scratch) - .map_err(|e| polars_err!(ComputeError: "error parsing line: {}", e))?; - match value { - simd_json::BorrowedValue::Object(value) => { - buffers.iter_mut().try_for_each(|(s, inner)| { - match s.0.map_lookup(&value) { - Some(v) => inner.add(v)?, - None => inner.add_null(), - } - PolarsResult::Ok(()) - })?; - }, - _ => { - buffers.iter_mut().for_each(|(_, inner)| inner.add_null()); - }, - }; - true + buffers.iter_mut().for_each(|(_, inner)| inner.add_null()); }, }; - polars_ensure!(all_good, ComputeError: "invalid JSON: unexpected end of file"); Ok(n) } +#[derive(Default)] +struct Scratch { + json: Vec, + buffers: simd_json::Buffers, +} + +fn json_lines(bytes: &[u8]) -> impl Iterator { + // This previously used `serde_json`'s `RawValue` to deserialize chunks without really deserializing them. + // However, this convenience comes at a cost. serde_json allocates and parses and does UTF-8 validation, all + // things we don't need since we use simd_json for them. Also, `serde_json::StreamDeserializer` has a more + // ambitious goal: it wants to parse potentially *non-delimited* sequences of JSON values, while we know + // our values are line-delimited. Turns out, custom splitting is very easy, and gives a very nice performance boost. + bytes.split(|&byte| byte == b'\n').filter(|&bytes| { + bytes + .iter() + .any(|&byte| !matches!(byte, b' ' | b'\t' | b'\r')) + }) +} + fn parse_lines(bytes: &[u8], buffers: &mut PlIndexMap) -> PolarsResult<()> { - let mut buf = vec![]; - - // The `RawValue` is a pointer to the original JSON string and does not perform any deserialization. - // It is used to properly iterate over the lines without re-implementing the splitlines logic when this does the same thing. 
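The json_lines helper added above is what replaces the serde_json stream deserializer: NDJSON input is split on newlines and whitespace-only lines are skipped, leaving simd_json to do the actual parsing. A self-contained copy of that splitter with a quick check of its behaviour:

// Mirrors `json_lines`: split on '\n' and drop lines that contain nothing but
// spaces, tabs and carriage returns.
fn json_lines(bytes: &[u8]) -> impl Iterator<Item = &[u8]> {
    bytes.split(|&byte| byte == b'\n').filter(|&line| {
        line.iter()
            .any(|&byte| !matches!(byte, b' ' | b'\t' | b'\r'))
    })
}

fn main() {
    let src = b"{\"a\": 1}\r\n\r\n   \n{\"a\": 2}\n";
    let lines: Vec<&[u8]> = json_lines(src).collect();
    // Only the two lines with real content survive; the '\r' stays attached.
    assert_eq!(lines.len(), 2);
    assert_eq!(lines[0], &b"{\"a\": 1}\r"[..]);
}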
- let iter = - serde_json::Deserializer::from_slice(bytes).into_iter::>(); - for value_result in iter { - match value_result { - Ok(value) => { - let bytes = value.get().as_bytes(); - parse_impl(bytes, buffers, &mut buf)?; - }, - Err(e) => { - polars_bail!(ComputeError: "error parsing ndjson {}", e) - }, - } + let mut scratch = Scratch::default(); + + let iter = json_lines(bytes); + for bytes in iter { + parse_impl(bytes, buffers, &mut scratch)?; } Ok(()) } diff --git a/crates/polars-io/src/ndjson/mod.rs b/crates/polars-io/src/ndjson/mod.rs index 4ec6ffa7a1da..e19b857dbd35 100644 --- a/crates/polars-io/src/ndjson/mod.rs +++ b/crates/polars-io/src/ndjson/mod.rs @@ -10,11 +10,11 @@ pub fn infer_schema( reader: &mut R, infer_schema_len: Option, ) -> PolarsResult { - let data_types = polars_json::ndjson::iter_unique_dtypes(reader, infer_schema_len)?; - let data_type = - crate::json::infer::data_types_to_supertype(data_types.map(|dt| DataType::from(&dt)))?; - let schema = StructArray::get_fields(&data_type.to_arrow(CompatLevel::newest())) + let dtypes = polars_json::ndjson::iter_unique_dtypes(reader, infer_schema_len)?; + let dtype = crate::json::infer::dtypes_to_supertype(dtypes.map(|dt| DataType::from(&dt)))?; + let schema = StructArray::get_fields(&dtype.to_arrow(CompatLevel::newest())) .iter() + .map(Into::::into) .collect(); Ok(schema) } diff --git a/crates/polars-io/src/options.rs b/crates/polars-io/src/options.rs index 338bb819a099..4950b747d807 100644 --- a/crates/polars-io/src/options.rs +++ b/crates/polars-io/src/options.rs @@ -1,6 +1,5 @@ -use std::sync::Arc; - use polars_core::schema::SchemaRef; +use polars_utils::pl_str::PlSmallStr; use polars_utils::IdxSize; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -8,7 +7,7 @@ use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Eq, PartialEq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct RowIndex { - pub name: Arc, + pub name: PlSmallStr, pub offset: IdxSize, } diff --git a/crates/polars-io/src/parquet/read/async_impl.rs b/crates/polars-io/src/parquet/read/async_impl.rs index 97e4829581bc..562156405b95 100644 --- a/crates/polars-io/src/parquet/read/async_impl.rs +++ b/crates/polars-io/src/parquet/read/async_impl.rs @@ -8,7 +8,7 @@ use polars_core::config::{get_rg_prefetch_size, verbose}; use polars_core::prelude::*; use polars_parquet::read::RowGroupMetaData; use polars_parquet::write::FileMetaData; -use smartstring::alias::String as SmartString; +use polars_utils::pl_str::PlSmallStr; use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::Mutex; @@ -164,7 +164,7 @@ pub async fn fetch_metadata( /// Download rowgroups for the column whose indexes are given in `projection`. /// We concurrently download the columns for each field. async fn download_projection( - fields: Arc<[SmartString]>, + fields: Arc<[PlSmallStr]>, row_group: RowGroupMetaData, async_reader: Arc, sender: QueueSend, @@ -177,16 +177,12 @@ async fn download_projection( let mut ranges = Vec::with_capacity(fields.len()); let mut offsets = Vec::with_capacity(fields.len()); fields.iter().for_each(|name| { - let columns = row_group.columns(); - // A single column can have multiple matches (structs). 
- let iter = columns.iter().filter_map(|meta| { - if meta.descriptor().path_in_schema[0] == name.as_str() { - let (offset, len) = meta.byte_range(); - Some((offset, offset as usize..(offset + len) as usize)) - } else { - None - } + let iter = row_group.columns_under_root_iter(name).map(|meta| { + let byte_range = meta.byte_range(); + let offset = byte_range.start; + let byte_range = byte_range.start as usize..byte_range.end as usize; + (offset, byte_range) }); for (offset, range) in iter { @@ -214,33 +210,30 @@ async fn download_row_group( sender: QueueSend, rg_index: usize, ) -> bool { - if rg.columns().is_empty() { + if rg.n_columns() == 0 { return true; } - let offset = rg.columns().iter().map(|c| c.byte_range().0).min().unwrap(); - let (max_offset, len) = rg - .columns() - .iter() - .map(|c| c.byte_range()) - .max_by_key(|k| k.0) - .unwrap(); + + let full_byte_range = rg.full_byte_range(); + let full_byte_range = full_byte_range.start as usize..full_byte_range.end as usize; let result = async_reader - .get_range(offset as usize, (max_offset - offset + len) as usize) + .get_range( + full_byte_range.start, + full_byte_range.end - full_byte_range.start, + ) .await .map(|bytes| { - let base_offset = offset; ( rg_index, - rg.columns() - .iter() - .map(|c| { - let (offset, len) = c.byte_range(); - let slice_offset = offset - base_offset; - + rg.byte_ranges_iter() + .map(|range| { ( - offset, - bytes.slice(slice_offset as usize..(slice_offset + len) as usize), + range.start, + bytes.slice( + range.start as usize - full_byte_range.start + ..range.end as usize - full_byte_range.start, + ), ) }) .collect::(), @@ -264,10 +257,10 @@ impl FetchRowGroupsFromObjectStore { row_group_range: Range, row_groups: &[RowGroupMetaData], ) -> PolarsResult { - let projected_fields: Option> = projection.map(|projection| { + let projected_fields: Option> = projection.map(|projection| { projection .iter() - .map(|i| SmartString::from(schema.fields[*i].name.as_str())) + .map(|i| (schema.get_at_index(*i).as_ref().unwrap().0.clone())) .collect() }); @@ -277,6 +270,7 @@ impl FetchRowGroupsFromObjectStore { row_group_range .filter_map(|i| { let rg = &row_groups[i]; + let should_be_read = matches!(read_this_row_group(Some(pred), rg, &schema), Ok(true)); diff --git a/crates/polars-io/src/parquet/read/mmap.rs b/crates/polars-io/src/parquet/read/mmap.rs index 69ba42ac4c29..04edfc8400f4 100644 --- a/crates/polars-io/src/parquet/read/mmap.rs +++ b/crates/polars-io/src/parquet/read/mmap.rs @@ -1,3 +1,4 @@ +use arrow::array::Array; use arrow::datatypes::Field; #[cfg(feature = "async")] use bytes::Bytes; @@ -5,8 +6,7 @@ use bytes::Bytes; use polars_core::datatypes::PlHashMap; use polars_error::PolarsResult; use polars_parquet::read::{ - column_iter_to_arrays, get_field_columns, ArrayIter, BasicDecompressor, ColumnChunkMetaData, - Filter, PageReader, + column_iter_to_arrays, BasicDecompressor, ColumnChunkMetadata, Filter, PageReader, }; use polars_utils::mmap::{MemReader, MemSlice}; @@ -31,27 +31,29 @@ pub enum ColumnStore { /// For cloud files the relevant memory regions should have been prefetched. 
pub(super) fn mmap_columns<'a>( store: &'a ColumnStore, - columns: &'a [ColumnChunkMetaData], - field_name: &str, -) -> Vec<(&'a ColumnChunkMetaData, MemSlice)> { - get_field_columns(columns, field_name) - .into_iter() + field_columns: &'a [&ColumnChunkMetadata], +) -> Vec<(&'a ColumnChunkMetadata, MemSlice)> { + field_columns + .iter() .map(|meta| _mmap_single_column(store, meta)) .collect() } fn _mmap_single_column<'a>( store: &'a ColumnStore, - meta: &'a ColumnChunkMetaData, -) -> (&'a ColumnChunkMetaData, MemSlice) { - let (start, len) = meta.byte_range(); + meta: &'a ColumnChunkMetadata, +) -> (&'a ColumnChunkMetadata, MemSlice) { + let byte_range = meta.byte_range(); let chunk = match store { - ColumnStore::Local(mem_slice) => mem_slice.slice((start as usize)..(start + len) as usize), + ColumnStore::Local(mem_slice) => { + mem_slice.slice(byte_range.start as usize..byte_range.end as usize) + }, #[cfg(all(feature = "async", feature = "parquet"))] ColumnStore::Fetched(fetched) => { - let entry = fetched.get(&start).unwrap_or_else(|| { + let entry = fetched.get(&byte_range.start).unwrap_or_else(|| { panic!( - "mmap_columns: column with start {start} must be prefetched in ColumnStore.\n" + "mmap_columns: column with start {} must be prefetched in ColumnStore.\n", + byte_range.start ) }); MemSlice::from_bytes(entry.clone()) @@ -62,24 +64,18 @@ fn _mmap_single_column<'a>( // similar to arrow2 serializer, except this accepts a slice instead of a vec. // this allows us to memory map -pub(super) fn to_deserializer<'a>( - columns: Vec<(&ColumnChunkMetaData, MemSlice)>, +pub fn to_deserializer( + columns: Vec<(&ColumnChunkMetadata, MemSlice)>, field: Field, filter: Option, -) -> PolarsResult> { +) -> PolarsResult> { let (columns, types): (Vec<_>, Vec<_>) = columns .into_iter() .map(|(column_meta, chunk)| { // Advise fetching the data for the column chunk chunk.prefetch(); - let pages = PageReader::new( - MemReader::new(chunk), - column_meta, - std::sync::Arc::new(|_, _| true), - vec![], - usize::MAX, - ); + let pages = PageReader::new(MemReader::new(chunk), column_meta, vec![], usize::MAX); ( BasicDecompressor::new(pages, vec![]), &column_meta.descriptor().descriptor.primitive_type, diff --git a/crates/polars-io/src/parquet/read/mod.rs b/crates/polars-io/src/parquet/read/mod.rs index ea0549012dc5..14c24bce12ac 100644 --- a/crates/polars-io/src/parquet/read/mod.rs +++ b/crates/polars-io/src/parquet/read/mod.rs @@ -24,8 +24,21 @@ mod reader; mod to_metadata; mod utils; +const ROW_COUNT_OVERFLOW_ERR: PolarsError = PolarsError::ComputeError(ErrString::new_static( + "\ +Parquet file produces more than pow(2, 32) rows; \ +consider compiling with polars-bigidx feature (polars-u64-idx package on python), \ +or set 'streaming'", +)); + pub use options::{ParallelStrategy, ParquetOptions}; +use polars_error::{ErrString, PolarsError}; #[cfg(feature = "cloud")] pub use reader::ParquetAsyncReader; pub use reader::{BatchedParquetReader, ParquetReader}; pub use utils::materialize_empty_df; + +pub mod _internal { + pub use super::mmap::to_deserializer; + pub use super::predicates::read_this_row_group; +} diff --git a/crates/polars-io/src/parquet/read/predicates.rs b/crates/polars-io/src/parquet/read/predicates.rs index d3775864e1a3..fa1655be4846 100644 --- a/crates/polars-io/src/parquet/read/predicates.rs +++ b/crates/polars-io/src/parquet/read/predicates.rs @@ -1,4 +1,3 @@ -use arrow::datatypes::ArrowSchemaRef; use polars_core::prelude::*; use polars_parquet::read::statistics::{deserialize, Statistics}; use 
polars_parquet::read::RowGroupMetaData; @@ -9,40 +8,47 @@ impl ColumnStats { fn from_arrow_stats(stats: Statistics, field: &ArrowField) -> Self { Self::new( field.into(), - Some(Series::try_from(("", stats.null_count)).unwrap()), - Some(Series::try_from(("", stats.min_value)).unwrap()), - Some(Series::try_from(("", stats.max_value)).unwrap()), + Some(Series::try_from((PlSmallStr::EMPTY, stats.null_count)).unwrap()), + Some(Series::try_from((PlSmallStr::EMPTY, stats.min_value)).unwrap()), + Some(Series::try_from((PlSmallStr::EMPTY, stats.max_value)).unwrap()), ) } } -/// Collect the statistics in a column chunk. +/// Collect the statistics in a row-group pub(crate) fn collect_statistics( md: &RowGroupMetaData, schema: &ArrowSchema, ) -> PolarsResult> { - let mut stats = vec![]; + // TODO! fix this performance. This is a full sequential scan. + let stats = schema + .iter_values() + .map(|field| { + let iter = md.columns_under_root_iter(&field.name); - for field in schema.fields.iter() { - let st = deserialize(field, md)?; - stats.push(ColumnStats::from_arrow_stats(st, field)); + Ok(if iter.len() == 0 { + ColumnStats::new(field.into(), None, None, None) + } else { + ColumnStats::from_arrow_stats(deserialize(field, iter)?, field) + }) + }) + .collect::>>()?; + + if stats.is_empty() { + return Ok(None); } - Ok(if stats.is_empty() { - None - } else { - Some(BatchStats::new( - Arc::new(schema.into()), - stats, - Some(md.num_rows()), - )) - }) + Ok(Some(BatchStats::new( + Arc::new(Schema::from_arrow_schema(schema)), + stats, + Some(md.num_rows()), + ))) } -pub(super) fn read_this_row_group( +pub fn read_this_row_group( predicate: Option<&dyn PhysicalIoExpr>, md: &RowGroupMetaData, - schema: &ArrowSchemaRef, + schema: &ArrowSchema, ) -> PolarsResult { if let Some(pred) = predicate { if let Some(pred) = pred.as_stats_evaluator() { diff --git a/crates/polars-io/src/parquet/read/read_impl.rs b/crates/polars-io/src/parquet/read/read_impl.rs index 0fbb760ffb30..e43c34ca2d70 100644 --- a/crates/polars-io/src/parquet/read/read_impl.rs +++ b/crates/polars-io/src/parquet/read/read_impl.rs @@ -2,15 +2,19 @@ use std::borrow::Cow; use std::collections::VecDeque; use std::ops::{Deref, Range}; -use arrow::array::new_empty_array; -use arrow::bitmap::{Bitmap, MutableBitmap}; +use arrow::array::BooleanArray; +use arrow::bitmap::MutableBitmap; use arrow::datatypes::ArrowSchemaRef; +use polars_core::chunked_array::builder::NullChunkedBuilder; use polars_core::prelude::*; use polars_core::utils::{accumulate_dataframes_vertical, split_df}; use polars_core::POOL; -use polars_parquet::read::{self, ArrayIter, FileMetaData, Filter, PhysicalType, RowGroupMetaData}; +use polars_parquet::parquet::error::ParquetResult; +use polars_parquet::parquet::statistics::Statistics; +use polars_parquet::read::{ + self, ColumnChunkMetadata, FileMetaData, Filter, PhysicalType, RowGroupMetaData, +}; use polars_utils::mmap::MemSlice; -use polars_utils::vec::inplace_zip_filtermap; use rayon::prelude::*; #[cfg(feature = "cloud")] @@ -23,6 +27,7 @@ use super::{mmap, ParallelStrategy}; use crate::hive::materialize_hive_partitions; use crate::mmap::{MmapBytesReader, ReaderBytes}; use crate::parquet::metadata::FileMetaDataRef; +use crate::parquet::read::ROW_COUNT_OVERFLOW_ERR; use crate::predicates::{apply_predicate, PhysicalIoExpr}; use crate::utils::get_reader_bytes; use crate::utils::slice::split_slice_at_file; @@ -31,51 +36,64 @@ use crate::RowIndex; #[cfg(debug_assertions)] // Ensure we get the proper polars types from schema inference // 
This saves unneeded casts. -fn assert_dtypes(data_type: &ArrowDataType) { - match data_type { - ArrowDataType::Utf8 => { - unreachable!() - }, - ArrowDataType::Binary => { - unreachable!() - }, - ArrowDataType::List(_) => { - unreachable!() - }, - ArrowDataType::LargeList(inner) => { - assert_dtypes(&inner.data_type); - }, - ArrowDataType::Struct(fields) => { - for fld in fields { - assert_dtypes(fld.data_type()) - } - }, +fn assert_dtypes(dtype: &ArrowDataType) { + use ArrowDataType as D; + + match dtype { + // These should all be casted to the BinaryView / Utf8View variants + D::Utf8 | D::Binary | D::LargeUtf8 | D::LargeBinary => unreachable!(), + + // This should have been converted to a LargeList + D::List(_) => unreachable!(), + + // This should have been converted to a LargeList(Struct(_)) + D::Map(_, _) => unreachable!(), + + // Recursive checks + D::Dictionary(_, dtype, _) => assert_dtypes(dtype), + D::Extension(_, dtype, _) => assert_dtypes(dtype), + D::LargeList(inner) => assert_dtypes(&inner.dtype), + D::FixedSizeList(inner, _) => assert_dtypes(&inner.dtype), + D::Struct(fields) => fields.iter().for_each(|f| assert_dtypes(f.dtype())), + _ => {}, } } fn column_idx_to_series( column_i: usize, - md: &RowGroupMetaData, + // The metadata belonging to this column + field_md: &[&ColumnChunkMetadata], filter: Option, file_schema: &ArrowSchema, store: &mmap::ColumnStore, ) -> PolarsResult { - let field = &file_schema.fields[column_i]; + let field = file_schema.get_at_index(column_i).unwrap().1; #[cfg(debug_assertions)] { - assert_dtypes(field.data_type()) + assert_dtypes(field.dtype()) + } + let columns = mmap_columns(store, field_md); + let stats = columns + .iter() + .map(|(col_md, _)| col_md.statistics().transpose()) + .collect::>>>(); + let array = mmap::to_deserializer(columns, field.clone(), filter)?; + let mut series = Series::try_from((field, array))?; + + // We cannot really handle nested metadata at the moment. Just skip it. + use ArrowDataType as AD; + match field.dtype() { + AD::List(_) | AD::LargeList(_) | AD::Struct(_) | AD::FixedSizeList(_, _) => { + return Ok(series) + }, + _ => {}, } - - let columns = mmap_columns(store, md.columns(), &field.name); - let iter = mmap::to_deserializer(columns, field.clone(), filter)?; - - let mut series = array_iter_to_series(iter, field, None)?; // See if we can find some statistics for this series. If we cannot find anything just return // the series as is. - let Some(Ok(stats)) = md.columns()[column_i].statistics() else { + let Ok(Some(stats)) = stats.map(|mut s| s.pop().flatten()) else { return Ok(series); }; @@ -117,38 +135,6 @@ fn column_idx_to_series( Ok(series) } -pub(super) fn array_iter_to_series( - iter: ArrayIter, - field: &ArrowField, - num_rows: Option, -) -> PolarsResult { - let mut total_count = 0; - let chunks = match num_rows { - None => iter.collect::>>()?, - Some(n) => { - let mut out = Vec::with_capacity(2); - - for arr in iter { - let arr = arr?; - let len = arr.len(); - out.push(arr); - - total_count += len; - if total_count >= n { - break; - } - } - out - }, - }; - if chunks.is_empty() { - let arr = new_empty_array(field.data_type.clone()); - Series::try_from((field, arr)) - } else { - Series::try_from((field, chunks)) - } -} - #[allow(clippy::too_many_arguments)] fn rg_to_dfs( store: &mmap::ColumnStore, @@ -165,6 +151,20 @@ fn rg_to_dfs( use_statistics: bool, hive_partition_columns: Option<&[Series]>, ) -> PolarsResult> { + // If we are only interested in the row_index, we take a little special path here. 
+ if projection.is_empty() { + if let Some(row_index) = row_index { + let placeholder = + NullChunkedBuilder::new(PlSmallStr::from_static("__PL_TMP"), slice.1).finish(); + return Ok(vec![DataFrame::new(vec![placeholder.into_series()])? + .with_row_index( + row_index.name.clone(), + Some(row_index.offset + IdxSize::try_from(slice.0).unwrap()), + )? + .select(std::iter::once(row_index.name))?]); + } + } + use ParallelStrategy as S; if parallel == S::Prefiltered { @@ -229,94 +229,107 @@ fn rg_to_dfs_prefiltered( row_group_end: usize, file_metadata: &FileMetaData, schema: &ArrowSchemaRef, - live_variables: Vec>, + live_variables: Vec, predicate: &dyn PhysicalIoExpr, row_index: Option, projection: &[usize], use_statistics: bool, hive_partition_columns: Option<&[Series]>, ) -> PolarsResult> { - struct RowGroupInfo { - index: u32, - row_offset: IdxSize, - } - if row_group_end > u32::MAX as usize { polars_bail!(ComputeError: "Parquet file contains too many row groups (> {})", u32::MAX); } let mut row_offset = *previous_row_count; - let mut row_groups: Vec = (row_group_start..row_group_end) - .filter_map(|index| { - let md = &file_metadata.row_groups[index]; - - let current_offset = row_offset; - let current_row_count = md.num_rows() as IdxSize; - row_offset += current_row_count; - - if use_statistics { - match read_this_row_group(Some(predicate), &file_metadata.row_groups[index], schema) - { - Ok(false) => return None, - Ok(true) => {}, - Err(e) => return Some(Err(e)), - } - } + let rg_offsets: Vec = match row_index { + None => Vec::new(), + Some(_) => (row_group_start..row_group_end) + .map(|index| { + let md = &file_metadata.row_groups[index]; - Some(Ok(RowGroupInfo { - index: index as u32, - row_offset: current_offset, - })) - }) - .collect::>>()?; + let current_offset = row_offset; + let current_row_count = md.num_rows() as IdxSize; + row_offset += current_row_count; - let num_live_columns = live_variables.len(); - let num_dead_columns = projection.len() - num_live_columns; + current_offset + }) + .collect(), + }; + // Deduplicate the live variables let live_variables = live_variables .iter() .map(Deref::deref) .collect::>(); + // Get the number of live columns + let num_live_columns = live_variables.len(); + let num_dead_columns = projection.len() - num_live_columns; + // We create two look-up tables that map indexes offsets into the live- and dead-set onto // column indexes of the schema. let mut live_idx_to_col_idx = Vec::with_capacity(num_live_columns); let mut dead_idx_to_col_idx = Vec::with_capacity(num_dead_columns); - for (i, col) in file_metadata.schema().columns().iter().enumerate() { - if live_variables.contains(col.path_in_schema[0].deref()) { + for (i, field) in schema.iter_values().enumerate() { + if live_variables.contains(&field.name[..]) { live_idx_to_col_idx.push(i); } else { dead_idx_to_col_idx.push(i); } } - debug_assert_eq!(live_variables.len(), num_live_columns); + + debug_assert_eq!(live_idx_to_col_idx.len(), num_live_columns); debug_assert_eq!(dead_idx_to_col_idx.len(), num_dead_columns); - POOL.install(|| { - // Collect the data for the live columns - let mut live_columns = (0..row_groups.len() * num_live_columns) - .into_par_iter() - .map(|i| { - let col_idx = live_idx_to_col_idx[i % num_live_columns]; - let rg_idx = row_groups[i / num_live_columns].index as usize; + enum MaskSetting { + Auto, + Pre, + Post, + } + let mask_setting = + std::env::var("POLARS_PQ_PREFILTERED_MASK").map_or(MaskSetting::Auto, |v| match &v[..] 
{ + "auto" => MaskSetting::Auto, + "pre" => MaskSetting::Pre, + "post" => MaskSetting::Post, + _ => panic!("Invalid `POLARS_PQ_PREFILTERED_MASK` value '{v}'."), + }); + + let dfs: Vec> = POOL.install(|| { + // Set partitioned fields to prevent quadratic behavior. + // Ensure all row groups are partitioned. + + (row_group_start..row_group_end) + .into_par_iter() + .map(|rg_idx| { let md = &file_metadata.row_groups[rg_idx]; - column_idx_to_series(col_idx, md, None, schema, store) - }) - .collect::>>()?; - // Apply the predicate to the live columns and save the dataframe and the bitmask - let mut dfs = live_columns - .par_chunks_exact_mut(num_live_columns) - .enumerate() - .map(|(i, columns)| { - let rg = &row_groups[i]; - let rg_idx = rg.index as usize; + if use_statistics { + match read_this_row_group(Some(predicate), md, schema) { + Ok(false) => return Ok(None), + Ok(true) => {}, + Err(e) => return Err(e), + } + } + + // Collect the data for the live columns + let live_columns = (0..num_live_columns) + .into_par_iter() + .map(|i| { + let col_idx = live_idx_to_col_idx[i]; - let columns = columns.iter_mut().map(std::mem::take).collect::>(); + let name = schema.get_at_index(col_idx).unwrap().0; + let field_md = file_metadata.row_groups[rg_idx] + .columns_under_root_iter(name) + .collect::>(); + + column_idx_to_series(col_idx, field_md.as_slice(), None, schema, store) + }) + .collect::>>()?; + // Apply the predicate to the live columns and save the dataframe and the bitmask let md = &file_metadata.row_groups[rg_idx]; - let mut df = unsafe { DataFrame::new_no_checks(columns) }; + let mut df = unsafe { DataFrame::new_no_checks(live_columns) }; materialize_hive_partitions( &mut df, @@ -328,119 +341,156 @@ fn rg_to_dfs_prefiltered( let mask = s.bool().expect("filter predicates was not of type boolean"); if let Some(rc) = &row_index { - df.with_row_index_mut(&rc.name, Some(rg.row_offset + rc.offset)); + df.with_row_index_mut(rc.name.clone(), Some(rg_offsets[rg_idx] + rc.offset)); } df = df.filter(mask)?; - let mut bitmap = MutableBitmap::with_capacity(mask.len()); + let mut filter_mask = MutableBitmap::with_capacity(mask.len()); + // We need to account for the validity of the items for chunk in mask.downcast_iter() { - bitmap.extend_from_bitmap(chunk.values()); + match chunk.validity() { + None => filter_mask.extend_from_bitmap(chunk.values()), + Some(validity) => { + filter_mask.extend_from_bitmap(&(validity & chunk.values())) + }, + } } - let bitmap = bitmap.freeze(); - - debug_assert_eq!(md.num_rows(), bitmap.len()); - debug_assert_eq!(df.height(), bitmap.set_bits()); + let filter_mask = filter_mask.freeze(); - Ok((bitmap, df)) - }) - .collect::>>()?; + debug_assert_eq!(md.num_rows(), filter_mask.len()); + debug_assert_eq!(df.height(), filter_mask.set_bits()); - // Filter out the row-groups that do not include any rows that match the predicate. - inplace_zip_filtermap(&mut dfs, &mut row_groups, |(mask, df), rg| { - (mask.set_bits() > 0).then_some(((mask, df), rg)) - }); + if filter_mask.set_bits() == 0 { + return Ok(None); + } - for (_, df) in &dfs { - *previous_row_count += df.height() as IdxSize; - } + // We don't need to do any further work if there are no dead columns + if num_dead_columns == 0 { + return Ok(Some(df)); + } - // @TODO: Incorporate this if we how we can properly use it. The problem here is that - // different columns really have a different cost when it comes to collecting them. We - // would need a cost model to properly estimate this. 
- // - // // For bitmasks that are seemingly random (i.e. not clustered or biased towards 0 or 1), - // // filtering with a bitmask in the Parquet reader is actually around 1.5 - 2.2 times slower - // // than collecting everything and filtering afterwards. This is because stopping and - // // starting decoding is not free. - // // - // // To combat this we try to detect here how biased our data is. We do this with a bithack - // // that estimates the amount of switches from 0 to 1 and from 1 to 0. This can be SIMD-ed - // // very well and gives us quite good estimate of how random our bitmask is. Then, we select - // // the filter if the bitmask is not that random. - // let do_filter_rg = dfs - // .par_iter() - // .map(|(mask, _)| { - // let iter = mask.fast_iter_u64(); - // - // // The iter is TrustedLen so the size_hint is exact. - // let num_items = iter.size_hint().0; - // let num_switches = iter - // .map(|v| (v ^ v.rotate_right(1)).count_ones() as u64) - // .sum::(); - // - // // We ignore the iter remainder since we only really care about the average. - // let avg_num_switches_per_element = num_switches / num_items as u64; - // - // // We select the filter if the average amount of switches per 64 elements is less - // // than or equal to 2. - // avg_num_switches_per_element <= 2 - // }) - // .collect::>(); - - let mut rg_columns = (0..dfs.len() * num_dead_columns) - .into_par_iter() - .map(|i| { - let col_idx = dead_idx_to_col_idx[i % num_dead_columns]; - let rg_idx = row_groups[i / num_dead_columns].index as usize; + let prefilter_cost = matches!(mask_setting, MaskSetting::Auto) + .then(|| { + let num_edges = filter_mask.num_edges() as f64; + let rg_len = filter_mask.len() as f64; + + // @GB: I did quite some analysis on this. + // + // Pre-filtered and Post-filtered can both be faster in certain scenarios. + // + // - Pre-filtered is faster when there is some amount of clustering or + // sorting involved or if the number of values selected is small. + // - Post-filtering is faster when the predicate selects a somewhat random + // elements throughout the row group. + // + // The following is a heuristic value to try and estimate which one is + // faster. Essentially, it sees how many times it needs to switch between + // skipping items and collecting items and compares it against the number + // of values that it will collect. + // + // Closer to 0: pre-filtering is probably better. + // Closer to 1: post-filtering is probably better. 
+ (num_edges / rg_len).clamp(0.0, 1.0) + }) + .unwrap_or_default(); + + let rg_columns = (0..num_dead_columns) + .into_par_iter() + .map(|i| { + let col_idx = dead_idx_to_col_idx[i]; + let name = schema.get_at_index(col_idx).unwrap().0; + + #[cfg(debug_assertions)] + { + let md = &file_metadata.row_groups[rg_idx]; + debug_assert_eq!(md.num_rows(), mask.len()); + } + let field_md = file_metadata.row_groups[rg_idx] + .columns_under_root_iter(name) + .collect::>(); + + let pre = || { + column_idx_to_series( + col_idx, + field_md.as_slice(), + Some(Filter::new_masked(filter_mask.clone())), + schema, + store, + ) + }; + let post = || { + let array = column_idx_to_series( + col_idx, + field_md.as_slice(), + None, + schema, + store, + )?; + + debug_assert_eq!(array.len(), mask.len()); + + let mask_arr = BooleanArray::new( + ArrowDataType::Boolean, + filter_mask.clone(), + None, + ); + let mask_arr = BooleanChunked::from(mask_arr); + array.filter(&mask_arr) + }; + + let array = match mask_setting { + MaskSetting::Auto => { + // Prefiltering is more expensive for nested types so we make the cut-off + // higher. + let is_nested = + schema.get_at_index(col_idx).unwrap().1.dtype.is_nested(); + + // We empirically selected these numbers. + let do_prefilter = (is_nested && prefilter_cost <= 0.01) + || (!is_nested && prefilter_cost <= 0.02); + + if do_prefilter { + pre()? + } else { + post()? + } + }, + MaskSetting::Pre => pre()?, + MaskSetting::Post => post()?, + }; - let (mask, _) = &dfs[i / num_dead_columns]; + debug_assert_eq!(array.len(), filter_mask.set_bits()); - let md = &file_metadata.row_groups[rg_idx]; - debug_assert_eq!(md.num_rows(), mask.len()); - column_idx_to_series( - col_idx, - md, - Some(Filter::new_masked(mask.clone())), - schema, - store, - ) - }) - .collect::>>()?; + Ok(array) + }) + .collect::>>()?; - let mut rearranged_schema: Schema = Schema::new(); - if let Some(rc) = &row_index { - rearranged_schema.insert_at_index( - 0, - SmartString::from(rc.name.deref()), - IdxType::get_dtype(), - )?; - } - for i in live_idx_to_col_idx.iter().copied() { - rearranged_schema.insert_at_index( - rearranged_schema.len(), - schema.fields[i].name.clone().into(), - schema.fields[i].data_type().into(), - )?; - } - rearranged_schema.merge(Schema::from(schema.as_ref())); + let mut rearranged_schema = df.schema(); + rearranged_schema.merge(Schema::from_arrow_schema(schema.as_ref())); - rg_columns - .par_chunks_exact_mut(num_dead_columns) - .zip(dfs) - .map(|(rg_cols, (_, mut df))| { - let rg_cols = rg_cols.iter_mut().map(std::mem::take).collect::>(); + debug_assert!(rg_columns.iter().all(|v| v.len() == df.height())); // We first add the columns with the live columns at the start. Then, we do a // projections that puts the columns at the right spot. 
- df._add_columns(rg_cols, &rearranged_schema)?; - let df = df.select(schema.get_names())?; + df._add_columns(rg_columns, &rearranged_schema)?; + let df = df.select(schema.iter_names_cloned())?; - PolarsResult::Ok(df) + PolarsResult::Ok(Some(df)) }) - .collect::>>() - }) + .collect::>>>() + })?; + + let dfs: Vec = dfs.into_iter().flatten().collect(); + + let row_count: usize = dfs.iter().map(|df| df.height()).sum(); + let row_count = IdxSize::try_from(row_count).map_err(|_| ROW_COUNT_OVERFLOW_ERR)?; + *previous_row_count = previous_row_count + .checked_add(row_count) + .ok_or(ROW_COUNT_OVERFLOW_ERR)?; + + Ok(dfs) } #[allow(clippy::too_many_arguments)] @@ -469,6 +519,7 @@ fn rg_to_dfs_optionally_par_over_columns( for rg_idx in row_group_start..row_group_end { let md = &file_metadata.row_groups[rg_idx]; + let rg_slice = split_slice_at_file(&mut n_rows_processed, md.num_rows(), slice.0, slice_end); let current_row_count = md.num_rows() as IdxSize; @@ -490,9 +541,12 @@ fn rg_to_dfs_optionally_par_over_columns( projection .par_iter() .map(|column_i| { + let name = schema.get_at_index(*column_i).unwrap().0; + let part = md.columns_under_root_iter(name).collect::>(); + column_idx_to_series( *column_i, - md, + part.as_slice(), Some(Filter::new_ranged(rg_slice.0, rg_slice.0 + rg_slice.1)), schema, store, @@ -504,9 +558,12 @@ fn rg_to_dfs_optionally_par_over_columns( projection .iter() .map(|column_i| { + let name = schema.get_at_index(*column_i).unwrap().0; + let part = md.columns_under_root_iter(name).collect::>(); + column_idx_to_series( *column_i, - md, + part.as_slice(), Some(Filter::new_ranged(rg_slice.0, rg_slice.0 + rg_slice.1)), schema, store, @@ -517,13 +574,19 @@ fn rg_to_dfs_optionally_par_over_columns( let mut df = unsafe { DataFrame::new_no_checks(columns) }; if let Some(rc) = &row_index { - df.with_row_index_mut(&rc.name, Some(*previous_row_count + rc.offset)); + df.with_row_index_mut(rc.name.clone(), Some(*previous_row_count + rc.offset)); } materialize_hive_partitions(&mut df, schema.as_ref(), hive_partition_columns, rg_slice.1); apply_predicate(&mut df, predicate, true)?; - *previous_row_count += current_row_count; + *previous_row_count = previous_row_count.checked_add(current_row_count).ok_or_else(|| + polars_err!( + ComputeError: "Parquet file produces more than pow(2, 32) rows; \ + consider compiling with polars-bigidx feature (polars-u64-idx package on python), \ + or set 'streaming'" + ), + )?; dfs.push(df); if *previous_row_count as usize >= slice_end { @@ -563,7 +626,9 @@ fn rg_to_dfs_par_over_rg( let rg_md = &file_metadata.row_groups[i]; let rg_slice = split_slice_at_file(&mut n_rows_processed, rg_md.num_rows(), slice.0, slice_end); - *previous_row_count += rg_slice.1 as IdxSize; + *previous_row_count = previous_row_count + .checked_add(rg_slice.1 as IdxSize) + .ok_or(ROW_COUNT_OVERFLOW_ERR)?; if rg_slice.1 == 0 { continue; @@ -573,17 +638,15 @@ fn rg_to_dfs_par_over_rg( } let dfs = POOL.install(|| { + // Set partitioned fields to prevent quadratic behavior. + // Ensure all row groups are partitioned. row_groups .into_par_iter() - .map(|(rg_idx, md, slice, row_count_start)| { - if slice.1 == 0 - || use_statistics - && !read_this_row_group( - predicate, - &file_metadata.row_groups[rg_idx], - schema, - )? - { + .enumerate() + .map(|(iter_idx, (_rg_idx, _md, slice, row_count_start))| { + let md = &file_metadata.row_groups[iter_idx]; + + if slice.1 == 0 || use_statistics && !read_this_row_group(predicate, md, schema)? 
{ return Ok(None); } // test we don't read the parquet file if this env var is set @@ -595,9 +658,12 @@ fn rg_to_dfs_par_over_rg( let columns = projection .iter() .map(|column_i| { + let name = schema.get_at_index(*column_i).unwrap().0; + let field_md = md.columns_under_root_iter(name).collect::>(); + column_idx_to_series( *column_i, - md, + field_md.as_slice(), Some(Filter::new_ranged(slice.0, slice.0 + slice.1)), schema, store, @@ -608,7 +674,10 @@ fn rg_to_dfs_par_over_rg( let mut df = unsafe { DataFrame::new_no_checks(columns) }; if let Some(rc) = &row_index { - df.with_row_index_mut(&rc.name, Some(row_count_start as IdxSize + rc.offset)); + df.with_row_index_mut( + rc.name.clone(), + Some(row_count_start as IdxSize + rc.offset), + ); } materialize_hive_partitions( @@ -675,7 +744,13 @@ pub fn read_parquet( .unwrap_or_else(|| Cow::Owned((0usize..reader_schema.len()).collect::>())); if let ParallelStrategy::Auto = parallel { - if n_row_groups > materialized_projection.len() || n_row_groups > POOL.current_num_threads() + if predicate.is_some_and(|predicate| { + predicate.live_variables().map_or(0, |v| v.len()) * n_row_groups + >= POOL.current_num_threads() + }) { + parallel = ParallelStrategy::Prefiltered; + } else if n_row_groups > materialized_projection.len() + || n_row_groups > POOL.current_num_threads() { parallel = ParallelStrategy::RowGroups; } else { @@ -855,7 +930,7 @@ impl BatchedParquetReader { chunk_size: usize, use_statistics: bool, hive_partition_columns: Option>, - include_file_path: Option<(Arc, Arc)>, + include_file_path: Option<(PlSmallStr, Arc)>, mut parallel: ParallelStrategy, ) -> PolarsResult { let n_row_groups = metadata.row_groups.len(); @@ -895,7 +970,7 @@ impl BatchedParquetReader { use_statistics, hive_partition_columns: hive_partition_columns.map(Arc::from), include_file_path: include_file_path - .map(|(col, path)| StringChunked::full(&col, &path, 1)), + .map(|(col, path)| StringChunked::full(col, &path, 1)), has_returned: false, }) } @@ -1029,7 +1104,7 @@ impl BatchedParquetReader { // Re-use the same ChunkedArray if ca.len() < max_len { - *ca = ca.new_from_index(max_len, 0); + *ca = ca.new_from_index(0, max_len); } for df in &mut dfs { diff --git a/crates/polars-io/src/parquet/read/reader.rs b/crates/polars-io/src/parquet/read/reader.rs index 30eb593191eb..25e8852a92ce 100644 --- a/crates/polars-io/src/parquet/read/reader.rs +++ b/crates/polars-io/src/parquet/read/reader.rs @@ -38,7 +38,7 @@ pub struct ParquetReader { metadata: Option, predicate: Option>, hive_partition_columns: Option>, - include_file_path: Option<(Arc, Arc)>, + include_file_path: Option<(PlSmallStr, Arc)>, use_statistics: bool, } @@ -87,12 +87,10 @@ impl ParquetReader { let self_schema = self.schema()?; let self_schema = self_schema.as_ref(); - if let Some(ref projection) = self.projection { - let projection = projection.as_slice(); - + if let Some(projection) = self.projection.as_deref() { ensure_matching_schema( - &schema.try_project(projection)?, - &self_schema.try_project(projection)?, + &schema.try_project_indices(projection)?, + &self_schema.try_project_indices(projection)?, )?; } else { ensure_matching_schema(schema, self_schema)?; @@ -134,7 +132,7 @@ impl ParquetReader { pub fn with_include_file_path( mut self, - include_file_path: Option<(Arc, Arc)>, + include_file_path: Option<(PlSmallStr, Arc)>, ) -> Self { self.include_file_path = include_file_path; self @@ -234,7 +232,7 @@ impl SerReader for ParquetReader { unsafe { df.with_column_unchecked( StringChunked::full( - col, + 
col.clone(), value, if df.width() > 0 { df.height() } else { n_rows }, ) @@ -259,7 +257,7 @@ pub struct ParquetAsyncReader { row_index: Option, use_statistics: bool, hive_partition_columns: Option>, - include_file_path: Option<(Arc, Arc)>, + include_file_path: Option<(PlSmallStr, Arc)>, schema: Option, parallel: ParallelStrategy, } @@ -290,12 +288,10 @@ impl ParquetAsyncReader { let self_schema = self.schema().await?; let self_schema = self_schema.as_ref(); - if let Some(ref projection) = self.projection { - let projection = projection.as_slice(); - + if let Some(projection) = self.projection.as_deref() { ensure_matching_schema( - &schema.try_project(projection)?, - &self_schema.try_project(projection)?, + &schema.try_project_indices(projection)?, + &self_schema.try_project_indices(projection)?, )?; } else { ensure_matching_schema(schema, self_schema)?; @@ -362,7 +358,7 @@ impl ParquetAsyncReader { pub fn with_include_file_path( mut self, - include_file_path: Option<(Arc, Arc)>, + include_file_path: Option<(PlSmallStr, Arc)>, ) -> Self { self.include_file_path = include_file_path; self diff --git a/crates/polars-io/src/parquet/read/utils.rs b/crates/polars-io/src/parquet/read/utils.rs index bb476a5fad08..34cc752dd782 100644 --- a/crates/polars-io/src/parquet/read/utils.rs +++ b/crates/polars-io/src/parquet/read/utils.rs @@ -20,7 +20,7 @@ pub fn materialize_empty_df( let mut df = DataFrame::empty_with_arrow_schema(&schema); if let Some(row_index) = row_index { - df.insert_column(0, Series::new_empty(&row_index.name, &IDX_DTYPE)) + df.insert_column(0, Series::new_empty(row_index.name.clone(), &IDX_DTYPE)) .unwrap(); } diff --git a/crates/polars-io/src/parquet/write/writer.rs b/crates/polars-io/src/parquet/write/writer.rs index 3129421e21d7..99cf4c95a45b 100644 --- a/crates/polars-io/src/parquet/write/writer.rs +++ b/crates/polars-io/src/parquet/write/writer.rs @@ -132,15 +132,14 @@ where fn get_encodings(schema: &ArrowSchema) -> Vec> { schema - .fields - .iter() - .map(|f| transverse(&f.data_type, encoding_map)) + .iter_values() + .map(|f| transverse(&f.dtype, encoding_map)) .collect() } /// Declare encodings -fn encoding_map(data_type: &ArrowDataType) -> Encoding { - match data_type.to_physical_type() { +fn encoding_map(dtype: &ArrowDataType) -> Encoding { + match dtype.to_physical_type() { PhysicalType::Dictionary(_) | PhysicalType::LargeBinary | PhysicalType::LargeUtf8 diff --git a/crates/polars-io/src/partition.rs b/crates/polars-io/src/partition.rs index 98508cc14e5a..bea0b9958a63 100644 --- a/crates/polars-io/src/partition.rs +++ b/crates/polars-io/src/partition.rs @@ -28,18 +28,20 @@ impl WriteDataFrameToFile for IpcWriterOptions { } } -/// Write a partitioned parquet dataset. This functionality is unstable. -pub fn write_partitioned_dataset( +fn write_partitioned_dataset_impl( df: &mut DataFrame, path: &Path, - partition_by: &[S], - file_write_options: &O, + partition_by: Vec, + file_write_options: &W, chunk_size: usize, ) -> PolarsResult<()> where - S: AsRef, - O: WriteDataFrameToFile + Send + Sync, + W: WriteDataFrameToFile + Send + Sync, { + let partition_by = partition_by + .into_iter() + .map(Into::into) + .collect::>(); // Ensure we have a single chunk as the gather will otherwise rechunk per group. 
df.as_single_chunk_par(); @@ -52,8 +54,8 @@ where let partition_by_col_idx = partition_by .iter() .map(|x| { - let Some(i) = schema.index_of(x.as_ref()) else { - polars_bail!(col_not_found = x.as_ref()) + let Some(i) = schema.index_of(x.as_str()) else { + polars_bail!(col_not_found = x) }; Ok(i) }) @@ -184,3 +186,23 @@ where Ok(()) } + +/// Write a partitioned parquet dataset. This functionality is unstable. +pub fn write_partitioned_dataset( + df: &mut DataFrame, + path: &Path, + partition_by: I, + file_write_options: &W, + chunk_size: usize, +) -> PolarsResult<()> +where + I: IntoIterator, + S: Into, + W: WriteDataFrameToFile + Send + Sync, +{ + let partition_by = partition_by + .into_iter() + .map(Into::into) + .collect::>(); + write_partitioned_dataset_impl(df, path, partition_by, file_write_options, chunk_size) +} diff --git a/crates/polars-io/src/path_utils/mod.rs b/crates/polars-io/src/path_utils/mod.rs index 5c4e48f7e6e4..1795cda6ebd0 100644 --- a/crates/polars-io/src/path_utils/mod.rs +++ b/crates/polars-io/src/path_utils/mod.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use once_cell::sync::Lazy; use polars_core::config; use polars_core::error::{polars_bail, to_compute_err, PolarsError, PolarsResult}; +use polars_utils::pl_str::PlSmallStr; use regex::Regex; #[cfg(feature = "cloud")] @@ -88,7 +89,7 @@ pub fn expand_paths( paths: &[PathBuf], glob: bool, #[allow(unused_variables)] cloud_options: Option<&CloudOptions>, -) -> PolarsResult>> { +) -> PolarsResult> { expand_paths_hive(paths, glob, cloud_options, false).map(|x| x.0) } @@ -129,13 +130,69 @@ pub fn expand_paths_hive( glob: bool, #[allow(unused_variables)] cloud_options: Option<&CloudOptions>, check_directory_level: bool, -) -> PolarsResult<(Arc>, usize)> { +) -> PolarsResult<(Arc<[PathBuf]>, usize)> { let Some(first_path) = paths.first() else { return Ok((vec![].into(), 0)); }; let is_cloud = is_cloud_url(first_path); - let mut out_paths = vec![]; + + /// Wrapper around `Vec` that also tracks file extensions, so that + /// we don't have to traverse the entire list again to validate extensions. 
+ struct OutPaths { + paths: Vec, + exts: [Option<(PlSmallStr, usize)>; 2], + current_idx: usize, + } + + impl OutPaths { + fn update_ext_status( + current_idx: &mut usize, + exts: &mut [Option<(PlSmallStr, usize)>; 2], + value: &Path, + ) { + let ext = value + .extension() + .map(|x| PlSmallStr::from(x.to_str().unwrap())) + .unwrap_or(PlSmallStr::EMPTY); + + if exts[0].is_none() { + exts[0] = Some((ext, *current_idx)); + } else if exts[1].is_none() && ext != exts[0].as_ref().unwrap().0 { + exts[1] = Some((ext, *current_idx)); + } + + *current_idx += 1; + } + + fn push(&mut self, value: PathBuf) { + { + let current_idx = &mut self.current_idx; + let exts = &mut self.exts; + Self::update_ext_status(current_idx, exts, &value); + } + self.paths.push(value) + } + + fn extend(&mut self, values: impl IntoIterator) { + let current_idx = &mut self.current_idx; + let exts = &mut self.exts; + + self.paths.extend(values.into_iter().inspect(|x| { + Self::update_ext_status(current_idx, exts, x); + })) + } + + fn extend_from_slice(&mut self, values: &[PathBuf]) { + self.extend(values.iter().cloned()) + } + } + + let mut out_paths = OutPaths { + paths: vec![], + exts: [None, None], + current_idx: 0, + }; let mut hive_idx_tracker = HiveIdxTracker { idx: usize::MAX, @@ -337,31 +394,20 @@ pub fn expand_paths_hive( } } - let out_paths = if expanded_from_single_directory(paths, out_paths.as_ref()) { - // Require all file extensions to be the same when expanding a single directory. - let ext = out_paths[0].extension(); - - (0..out_paths.len()) - .map(|i| { - let path = out_paths[i].clone(); - - if path.extension() != ext { - polars_bail!( - InvalidOperation: r#"directory contained paths with different file extensions: \ - first path: {}, second path: {}. Please use a glob pattern to explicitly specify \ - which files to read (e.g. "dir/**/*", "dir/**/*.parquet")"#, - out_paths[i - 1].to_str().unwrap(), path.to_str().unwrap() - ); - }; + assert_eq!(out_paths.current_idx, out_paths.paths.len()); - Ok(path) - }) - .collect::>>()? - } else { - out_paths - }; + if expanded_from_single_directory(paths, out_paths.paths.as_slice()) { + if let [Some((_, i1)), Some((_, i2))] = out_paths.exts { + polars_bail!( + InvalidOperation: r#"directory contained paths with different file extensions: \ + first path: {}, second path: {}. Please use a glob pattern to explicitly specify \ + which files to read (e.g. "dir/**/*", "dir/**/*.parquet")"#, + &out_paths.paths[i1].to_string_lossy(), &out_paths.paths[i2].to_string_lossy() + ) + } + } - Ok((Arc::new(out_paths), hive_idx_tracker.idx)) + Ok((out_paths.paths.into(), hive_idx_tracker.idx)) } /// Ignores errors from `std::fs::create_dir_all` if the directory exists. 
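// --- Editor's illustrative sketch (not part of the patch) ---
// The `OutPaths` wrapper introduced above records at most two distinct file
// extensions, each paired with the index of the first path that carries it, so the
// mixed-extension `InvalidOperation` error can name both offending paths without
// re-traversing the expanded list. A std-only reduction of the same idea, with
// hypothetical names and nothing assumed from polars itself:
use std::ffi::OsStr;
use std::path::PathBuf;

/// Returns the indices of the first two paths whose extensions differ, if any.
fn first_extension_mismatch(paths: &[PathBuf]) -> Option<(usize, usize)> {
    let mut first: Option<(Option<&OsStr>, usize)> = None;
    for (idx, path) in paths.iter().enumerate() {
        let ext = path.extension();
        match first {
            // Remember the extension of the first path seen.
            None => first = Some((ext, idx)),
            // First deviation: report both indices and stop scanning.
            Some((first_ext, first_idx)) if ext != first_ext => {
                return Some((first_idx, idx));
            },
            _ => {},
        }
    }
    None
}

fn main() {
    let paths = [PathBuf::from("dir/a.parquet"), PathBuf::from("dir/b.csv")];
    // Prints `Some((0, 1))`: the pair of paths a mixed-extension error would name.
    println!("{:?}", first_extension_mismatch(&paths));
}
// --- end sketch ---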
diff --git a/crates/polars-io/src/pl_async.rs b/crates/polars-io/src/pl_async.rs index b9731751483d..cc43a908cda3 100644 --- a/crates/polars-io/src/pl_async.rs +++ b/crates/polars-io/src/pl_async.rs @@ -237,8 +237,16 @@ pub struct RuntimeManager { impl RuntimeManager { fn new() -> Self { + let n_threads = std::env::var("POLARS_ASYNC_THREAD_COUNT") + .map(|x| x.parse::().expect("integer")) + .unwrap_or((POOL.current_num_threads() / 4).clamp(1, 4)); + + if polars_core::config::verbose() { + eprintln!("Async thread count: {}", n_threads); + } + let rt = Builder::new_multi_thread() - .worker_threads(std::cmp::max(POOL.current_num_threads(), 4)) + .worker_threads(n_threads) .enable_io() .enable_time() .build() diff --git a/crates/polars-io/src/predicates.rs b/crates/polars-io/src/predicates.rs index 08ad7685461c..b46600666c44 100644 --- a/crates/polars-io/src/predicates.rs +++ b/crates/polars-io/src/predicates.rs @@ -8,7 +8,7 @@ pub trait PhysicalIoExpr: Send + Sync { fn evaluate_io(&self, df: &DataFrame) -> PolarsResult; /// Get the variables that are used in the expression i.e. live variables. - fn live_variables(&self) -> Option>>; + fn live_variables(&self) -> Option>; /// Can take &dyn Statistics and determine of a file should be /// read -> `true` @@ -94,13 +94,13 @@ impl ColumnStats { } } - pub fn field_name(&self) -> &SmartString { + pub fn field_name(&self) -> &PlSmallStr { self.field.name() } /// Returns the [`DataType`] of the column. pub fn dtype(&self) -> &DataType { - self.field.data_type() + self.field.dtype() } /// Returns the null count of each row group of the column. diff --git a/crates/polars-io/src/shared.rs b/crates/polars-io/src/shared.rs index 735490b0bcb3..7fbb5eb96e7f 100644 --- a/crates/polars-io/src/shared.rs +++ b/crates/polars-io/src/shared.rs @@ -65,10 +65,10 @@ pub(crate) fn finish_reader( while let Some(batch) = reader.next_record_batch()? { let current_num_rows = num_rows as IdxSize; num_rows += batch.len(); - let mut df = DataFrame::try_from((batch, arrow_schema.fields.as_slice()))?; + let mut df = DataFrame::try_from((batch, arrow_schema))?; if let Some(rc) = &row_index { - df.with_row_index_mut(&rc.name, Some(current_num_rows + rc.offset)); + df.with_row_index_mut(rc.name.clone(), Some(current_num_rows + rc.offset)); } if let Some(predicate) = &predicate { @@ -97,11 +97,8 @@ pub(crate) fn finish_reader( if parsed_dfs.is_empty() { // Create an empty dataframe with the correct data types let empty_cols = arrow_schema - .fields - .iter() - .map(|fld| { - Series::try_from((fld.name.as_str(), new_empty_array(fld.data_type.clone()))) - }) + .iter_values() + .map(|fld| Series::try_from((fld.name.clone(), new_empty_array(fld.dtype.clone())))) .collect::>()?; DataFrame::new(empty_cols)? 
} else { @@ -121,10 +118,22 @@ pub(crate) fn schema_to_arrow_checked( compat_level: CompatLevel, _file_name: &str, ) -> PolarsResult { - let fields = schema.iter_fields().map(|field| { - #[cfg(feature = "object")] - polars_ensure!(!matches!(field.data_type(), DataType::Object(_, _)), ComputeError: "cannot write 'Object' datatype to {}", _file_name); - Ok(field.data_type().to_arrow_field(field.name().as_str(), compat_level)) - }).collect::>>()?; - Ok(ArrowSchema::from(fields)) + schema + .iter_fields() + .map(|field| { + #[cfg(feature = "object")] + { + polars_ensure!( + !matches!(field.dtype(), DataType::Object(_, _)), + ComputeError: "cannot write 'Object' datatype to {}", + _file_name + ); + } + + let field = field + .dtype() + .to_arrow_field(field.name().clone(), compat_level); + Ok((field.name.clone(), field)) + }) + .collect::>() } diff --git a/crates/polars-io/src/utils/byte_source.rs b/crates/polars-io/src/utils/byte_source.rs new file mode 100644 index 000000000000..72cbabb3dd5c --- /dev/null +++ b/crates/polars-io/src/utils/byte_source.rs @@ -0,0 +1,175 @@ +use std::ops::Range; +use std::sync::Arc; + +use polars_error::PolarsResult; +use polars_utils::_limit_path_len_io_err; +use polars_utils::mmap::MemSlice; + +use crate::cloud::{ + build_object_store, object_path_from_str, CloudLocation, CloudOptions, ObjectStorePath, + PolarsObjectStore, +}; + +#[allow(async_fn_in_trait)] +pub trait ByteSource: Send + Sync { + async fn get_size(&self) -> PolarsResult; + /// # Panics + /// Panics if `range` is not in bounds. + async fn get_range(&self, range: Range) -> PolarsResult; + async fn get_ranges(&self, ranges: &[Range]) -> PolarsResult>; +} + +/// Byte source backed by a `MemSlice`, which can potentially be memory-mapped. +pub struct MemSliceByteSource(pub MemSlice); + +impl MemSliceByteSource { + async fn try_new_mmap_from_path( + path: &str, + _cloud_options: Option<&CloudOptions>, + ) -> PolarsResult { + let file = Arc::new( + tokio::fs::File::open(path) + .await + .map_err(|err| _limit_path_len_io_err(path.as_ref(), err))? + .into_std() + .await, + ); + + Ok(Self(MemSlice::from_file(file.as_ref())?)) + } +} + +impl ByteSource for MemSliceByteSource { + async fn get_size(&self) -> PolarsResult { + Ok(self.0.as_ref().len()) + } + + async fn get_range(&self, range: Range) -> PolarsResult { + let out = self.0.slice(range); + Ok(out) + } + + async fn get_ranges(&self, ranges: &[Range]) -> PolarsResult> { + Ok(ranges + .iter() + .map(|x| self.0.slice(x.clone())) + .collect::>()) + } +} + +pub struct ObjectStoreByteSource { + store: PolarsObjectStore, + path: ObjectStorePath, +} + +impl ObjectStoreByteSource { + async fn try_new_from_path( + path: &str, + cloud_options: Option<&CloudOptions>, + ) -> PolarsResult { + let (CloudLocation { prefix, .. 
}, store) = + build_object_store(path, cloud_options, false).await?; + let path = object_path_from_str(&prefix)?; + let store = PolarsObjectStore::new(store); + + Ok(Self { store, path }) + } +} + +impl ByteSource for ObjectStoreByteSource { + async fn get_size(&self) -> PolarsResult { + Ok(self.store.head(&self.path).await?.size) + } + + async fn get_range(&self, range: Range) -> PolarsResult { + let bytes = self.store.get_range(&self.path, range).await?; + let mem_slice = MemSlice::from_bytes(bytes); + + Ok(mem_slice) + } + + async fn get_ranges(&self, ranges: &[Range]) -> PolarsResult> { + let ranges = self.store.get_ranges(&self.path, ranges).await?; + Ok(ranges.into_iter().map(MemSlice::from_bytes).collect()) + } +} + +/// Dynamic dispatch to async functions. +pub enum DynByteSource { + MemSlice(MemSliceByteSource), + Cloud(ObjectStoreByteSource), +} + +impl DynByteSource { + pub fn variant_name(&self) -> &str { + match self { + Self::MemSlice(_) => "MemSlice", + Self::Cloud(_) => "Cloud", + } + } +} + +impl Default for DynByteSource { + fn default() -> Self { + Self::MemSlice(MemSliceByteSource(MemSlice::default())) + } +} + +impl ByteSource for DynByteSource { + async fn get_size(&self) -> PolarsResult { + match self { + Self::MemSlice(v) => v.get_size().await, + Self::Cloud(v) => v.get_size().await, + } + } + + async fn get_range(&self, range: Range) -> PolarsResult { + match self { + Self::MemSlice(v) => v.get_range(range).await, + Self::Cloud(v) => v.get_range(range).await, + } + } + + async fn get_ranges(&self, ranges: &[Range]) -> PolarsResult> { + match self { + Self::MemSlice(v) => v.get_ranges(ranges).await, + Self::Cloud(v) => v.get_ranges(ranges).await, + } + } +} + +impl From for DynByteSource { + fn from(value: MemSliceByteSource) -> Self { + Self::MemSlice(value) + } +} + +impl From for DynByteSource { + fn from(value: ObjectStoreByteSource) -> Self { + Self::Cloud(value) + } +} + +#[derive(Clone, Debug)] +pub enum DynByteSourceBuilder { + Mmap, + /// Supports both cloud and local files. + ObjectStore, +} + +impl DynByteSourceBuilder { + pub async fn try_build_from_path( + &self, + path: &str, + cloud_options: Option<&CloudOptions>, + ) -> PolarsResult { + Ok(match self { + Self::Mmap => MemSliceByteSource::try_new_mmap_from_path(path, cloud_options) + .await? + .into(), + Self::ObjectStore => ObjectStoreByteSource::try_new_from_path(path, cloud_options) + .await? 
+ .into(), + }) + } +} diff --git a/crates/polars-io/src/utils/mod.rs b/crates/polars-io/src/utils/mod.rs index 5ed22c76561c..87c80b1b5c5a 100644 --- a/crates/polars-io/src/utils/mod.rs +++ b/crates/polars-io/src/utils/mod.rs @@ -3,6 +3,8 @@ mod other; pub use compression::is_compressed; pub use other::*; +#[cfg(feature = "cloud")] +pub mod byte_source; pub mod slice; pub const URL_ENCODE_CHAR_SET: &percent_encoding::AsciiSet = &percent_encoding::CONTROLS diff --git a/crates/polars-io/src/utils/other.rs b/crates/polars-io/src/utils/other.rs index 22b3a8d82b18..1984e6ad480e 100644 --- a/crates/polars-io/src/utils/other.rs +++ b/crates/polars-io/src/utils/other.rs @@ -7,6 +7,7 @@ use polars_core::prelude::*; #[cfg(any(feature = "ipc_streaming", feature = "parquet"))] use polars_core::utils::{accumulate_dataframes_vertical_unchecked, split_df_as_ref}; use polars_error::to_compute_err; +use polars_utils::mmap::MMapSemaphore; use regex::{Regex, RegexBuilder}; use crate::mmap::{MmapBytesReader, ReaderBytes}; @@ -21,12 +22,15 @@ pub fn get_reader_bytes<'a, R: Read + MmapBytesReader + ?Sized>( .ok() .and_then(|offset| Some((reader.to_file()?, offset))) { - let mmap = unsafe { memmap::MmapOptions::new().offset(offset).map(file)? }; + let mut options = memmap::MmapOptions::new(); + options.offset(offset); // somehow bck thinks borrows alias // this is sound as file was already bound to 'a use std::fs::File; + let file = unsafe { std::mem::transmute::<&File, &'a File>(file) }; + let mmap = MMapSemaphore::new_from_file_with_options(file, options)?; Ok(ReaderBytes::Mapped(mmap, file)) } else { // we can get the bytes for free @@ -89,12 +93,11 @@ pub fn maybe_decompress_bytes<'a>(bytes: &'a [u8], out: &'a mut Vec) -> Pola feature = "avro" ))] pub(crate) fn apply_projection(schema: &ArrowSchema, projection: &[usize]) -> ArrowSchema { - let fields = &schema.fields; - let fields = projection + projection .iter() - .map(|idx| fields[*idx].clone()) - .collect::>(); - ArrowSchema::from(fields) + .map(|idx| schema.get_at_index(*idx).unwrap()) + .map(|(k, v)| (k.clone(), v.clone())) + .collect() } #[cfg(any( @@ -108,26 +111,10 @@ pub(crate) fn columns_to_projection( schema: &ArrowSchema, ) -> PolarsResult> { let mut prj = Vec::with_capacity(columns.len()); - if columns.len() > 100 { - let mut column_names = PlHashMap::with_capacity(schema.fields.len()); - schema.fields.iter().enumerate().for_each(|(i, c)| { - column_names.insert(c.name.as_str(), i); - }); - - for column in columns.iter() { - let Some(&i) = column_names.get(column.as_str()) else { - polars_bail!( - ColumnNotFound: - "unable to find column {:?}; valid columns: {:?}", column, schema.get_names(), - ); - }; - prj.push(i); - } - } else { - for column in columns.iter() { - let i = schema.try_index_of(column)?; - prj.push(i); - } + + for column in columns { + let i = schema.try_index_of(column)?; + prj.push(i); } Ok(prj) @@ -210,7 +197,7 @@ pub static BOOLEAN_RE: Lazy = Lazy::new(|| { }); pub fn materialize_projection( - with_columns: Option<&[String]>, + with_columns: Option<&[PlSmallStr]>, schema: &Schema, hive_partitions: Option<&[Series]>, has_row_index: bool, diff --git a/crates/polars-io/src/utils/slice.rs b/crates/polars-io/src/utils/slice.rs index 78ff29cf1b29..24a3b7dc1ab8 100644 --- a/crates/polars-io/src/utils/slice.rs +++ b/crates/polars-io/src/utils/slice.rs @@ -1,33 +1,58 @@ /// Given a `slice` that is relative to the start of a list of files, calculate the slice to apply /// at a file with a row offset of `current_row_offset`. 
pub fn split_slice_at_file( - current_row_offset: &mut usize, + current_row_offset_ref: &mut usize, n_rows_this_file: usize, global_slice_start: usize, global_slice_end: usize, ) -> (usize, usize) { - let next_file_offset = *current_row_offset + n_rows_this_file; - // e.g. - // slice: (start: 1, end: 2) - // files: - // 0: (1 row): current_offset: 0, next_file_offset: 1 - // 1: (1 row): current_offset: 1, next_file_offset: 2 - // 2: (1 row): current_offset: 2, next_file_offset: 3 - // in this example we want to include only file 1. - let has_overlap_with_slice = - *current_row_offset < global_slice_end && next_file_offset > global_slice_start; + let current_row_offset = *current_row_offset_ref; + *current_row_offset_ref += n_rows_this_file; + match SplitSlicePosition::split_slice_at_file( + current_row_offset, + n_rows_this_file, + global_slice_start..global_slice_end, + ) { + SplitSlicePosition::Overlapping(offset, len) => (offset, len), + SplitSlicePosition::Before | SplitSlicePosition::After => (0, 0), + } +} + +#[derive(Debug)] +pub enum SplitSlicePosition { + Before, + Overlapping(usize, usize), + After, +} + +impl SplitSlicePosition { + pub fn split_slice_at_file( + current_row_offset: usize, + n_rows_this_file: usize, + global_slice: std::ops::Range, + ) -> Self { + // e.g. + // slice: (start: 1, end: 2) + // files: + // 0: (1 row): current_offset: 0, next_file_offset: 1 + // 1: (1 row): current_offset: 1, next_file_offset: 2 + // 2: (1 row): current_offset: 2, next_file_offset: 3 + // in this example we want to include only file 1. + + let next_row_offset = current_row_offset + n_rows_this_file; - let (rel_start, slice_len) = if !has_overlap_with_slice { - (0, 0) - } else { - let n_rows_to_skip = global_slice_start.saturating_sub(*current_row_offset); - let n_excess_rows = next_file_offset.saturating_sub(global_slice_end); - ( - n_rows_to_skip, - n_rows_this_file - n_rows_to_skip - n_excess_rows, - ) - }; + if next_row_offset <= global_slice.start { + Self::Before + } else if current_row_offset >= global_slice.end { + Self::After + } else { + let n_rows_to_skip = global_slice.start.saturating_sub(current_row_offset); + let n_excess_rows = next_row_offset.saturating_sub(global_slice.end); - *current_row_offset = next_file_offset; - (rel_start, slice_len) + Self::Overlapping( + n_rows_to_skip, + n_rows_this_file - n_rows_to_skip - n_excess_rows, + ) + } + } } diff --git a/crates/polars-json/src/json/deserialize.rs b/crates/polars-json/src/json/deserialize.rs index 9a4c9e27d0cb..eb4c12954a8d 100644 --- a/crates/polars-json/src/json/deserialize.rs +++ b/crates/polars-json/src/json/deserialize.rs @@ -91,9 +91,9 @@ fn deserialize_utf8view_into<'a, A: Borrow>>( fn deserialize_list<'a, A: Borrow>>( rows: &[A], - data_type: ArrowDataType, + dtype: ArrowDataType, ) -> ListArray { - let child = ListArray::::get_child_type(&data_type); + let child = ListArray::::get_child_type(&dtype); let mut validity = MutableBitmap::with_capacity(rows.len()); let mut offsets = Offsets::::with_capacity(rows.len()); @@ -123,18 +123,18 @@ fn deserialize_list<'a, A: Borrow>>( let values = _deserialize(&inner, child.clone()); - ListArray::::new(data_type, offsets.into(), values, validity.into()) + ListArray::::new(dtype, offsets.into(), values, validity.into()) } fn deserialize_struct<'a, A: Borrow>>( rows: &[A], - data_type: ArrowDataType, + dtype: ArrowDataType, ) -> StructArray { - let fields = StructArray::get_fields(&data_type); + let fields = StructArray::get_fields(&dtype); let mut values = fields .iter() 
- .map(|f| (f.name.as_str(), (f.data_type(), vec![]))) + .map(|f| (f.name.as_str(), (f.dtype(), vec![]))) .collect::>(); let mut validity = MutableBitmap::with_capacity(rows.len()); @@ -160,24 +160,24 @@ fn deserialize_struct<'a, A: Borrow>>( let values = fields .iter() .map(|fld| { - let (data_type, vals) = values.get(fld.name.as_str()).unwrap(); - _deserialize(vals, (*data_type).clone()) + let (dtype, vals) = values.get(fld.name.as_str()).unwrap(); + _deserialize(vals, (*dtype).clone()) }) .collect::>(); - StructArray::new(data_type.clone(), values, validity.into()) + StructArray::new(dtype.clone(), values, validity.into()) } fn fill_array_from( f: fn(&mut MutablePrimitiveArray, &[B]), - data_type: ArrowDataType, + dtype: ArrowDataType, rows: &[B], ) -> Box where T: NativeType, A: From> + Array, { - let mut array = MutablePrimitiveArray::::with_capacity(rows.len()).to(data_type); + let mut array = MutablePrimitiveArray::::with_capacity(rows.len()).to(dtype); f(&mut array, rows); Box::new(A::from(array)) } @@ -248,30 +248,24 @@ where pub(crate) fn _deserialize<'a, A: Borrow>>( rows: &[A], - data_type: ArrowDataType, + dtype: ArrowDataType, ) -> Box { - match &data_type { - ArrowDataType::Null => Box::new(NullArray::new(data_type, rows.len())), + match &dtype { + ArrowDataType::Null => Box::new(NullArray::new(dtype, rows.len())), ArrowDataType::Boolean => { fill_generic_array_from::<_, _, BooleanArray>(deserialize_boolean_into, rows) }, ArrowDataType::Int8 => { - fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, data_type, rows) + fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) + }, + ArrowDataType::Int16 => { + fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) }, - ArrowDataType::Int16 => fill_array_from::<_, _, PrimitiveArray>( - deserialize_primitive_into, - data_type, - rows, - ), ArrowDataType::Int32 | ArrowDataType::Date32 | ArrowDataType::Time32(_) | ArrowDataType::Interval(IntervalUnit::YearMonth) => { - fill_array_from::<_, _, PrimitiveArray>( - deserialize_primitive_into, - data_type, - rows, - ) + fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) }, ArrowDataType::Interval(IntervalUnit::DayTime) => { unimplemented!("There is no natural representation of DayTime in JSON.") @@ -279,73 +273,61 @@ pub(crate) fn _deserialize<'a, A: Borrow>>( ArrowDataType::Int64 | ArrowDataType::Date64 | ArrowDataType::Time64(_) - | ArrowDataType::Duration(_) => fill_array_from::<_, _, PrimitiveArray>( - deserialize_primitive_into, - data_type, - rows, - ), + | ArrowDataType::Duration(_) => { + fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) + }, ArrowDataType::Timestamp(tu, tz) => { let iter = rows.iter().map(|row| match row.borrow() { BorrowedValue::Static(StaticNode::I64(v)) => Some(*v), BorrowedValue::String(v) => match (tu, tz) { (_, None) => temporal_conversions::utf8_to_naive_timestamp_scalar(v, "%+", tu), (_, Some(ref tz)) => { - let tz = temporal_conversions::parse_offset(tz).unwrap(); + let tz = temporal_conversions::parse_offset(tz.as_str()).unwrap(); temporal_conversions::utf8_to_timestamp_scalar(v, "%+", &tz, tu) }, }, _ => None, }); - Box::new(Int64Array::from_iter(iter).to(data_type)) + Box::new(Int64Array::from_iter(iter).to(dtype)) }, ArrowDataType::UInt8 => { - fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, data_type, rows) + fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) + }, + 
ArrowDataType::UInt16 => { + fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) + }, + ArrowDataType::UInt32 => { + fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) + }, + ArrowDataType::UInt64 => { + fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) }, - ArrowDataType::UInt16 => fill_array_from::<_, _, PrimitiveArray>( - deserialize_primitive_into, - data_type, - rows, - ), - ArrowDataType::UInt32 => fill_array_from::<_, _, PrimitiveArray>( - deserialize_primitive_into, - data_type, - rows, - ), - ArrowDataType::UInt64 => fill_array_from::<_, _, PrimitiveArray>( - deserialize_primitive_into, - data_type, - rows, - ), ArrowDataType::Float16 => unreachable!(), - ArrowDataType::Float32 => fill_array_from::<_, _, PrimitiveArray>( - deserialize_primitive_into, - data_type, - rows, - ), - ArrowDataType::Float64 => fill_array_from::<_, _, PrimitiveArray>( - deserialize_primitive_into, - data_type, - rows, - ), + ArrowDataType::Float32 => { + fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) + }, + ArrowDataType::Float64 => { + fill_array_from::<_, _, PrimitiveArray>(deserialize_primitive_into, dtype, rows) + }, ArrowDataType::LargeUtf8 => { fill_generic_array_from::<_, _, Utf8Array>(deserialize_utf8_into, rows) }, ArrowDataType::Utf8View => { fill_generic_array_from::<_, _, Utf8ViewArray>(deserialize_utf8view_into, rows) }, - ArrowDataType::LargeList(_) => Box::new(deserialize_list(rows, data_type)), + ArrowDataType::LargeList(_) => Box::new(deserialize_list(rows, dtype)), ArrowDataType::LargeBinary => Box::new(deserialize_binary(rows)), - ArrowDataType::Struct(_) => Box::new(deserialize_struct(rows, data_type)), + ArrowDataType::Struct(_) => Box::new(deserialize_struct(rows, dtype)), _ => todo!(), } } -pub fn deserialize(json: &BorrowedValue, data_type: ArrowDataType) -> PolarsResult> { +pub fn deserialize(json: &BorrowedValue, dtype: ArrowDataType) -> PolarsResult> { match json { - BorrowedValue::Array(rows) => match data_type { - ArrowDataType::LargeList(inner) => Ok(_deserialize(rows, inner.data_type)), + BorrowedValue::Array(rows) => match dtype { + ArrowDataType::LargeList(inner) => Ok(_deserialize(rows, inner.dtype)), _ => todo!("read an Array from a non-Array data type"), }, - _ => Ok(_deserialize(&[json], data_type)), + _ => Ok(_deserialize(&[json], dtype)), } } diff --git a/crates/polars-json/src/json/infer_schema.rs b/crates/polars-json/src/json/infer_schema.rs index a525334a3d8c..4d0eb4d47309 100644 --- a/crates/polars-json/src/json/infer_schema.rs +++ b/crates/polars-json/src/json/infer_schema.rs @@ -2,6 +2,7 @@ use std::borrow::Borrow; use arrow::datatypes::{ArrowDataType, Field}; use indexmap::map::Entry; +use polars_utils::pl_str::PlSmallStr; use simd_json::borrowed::Object; use simd_json::{BorrowedValue, StaticNode}; @@ -30,7 +31,7 @@ fn infer_object(inner: &Object) -> PolarsResult { .map(|(key, value)| infer(value).map(|dt| (key, dt))) .map(|maybe_dt| { let (key, dt) = maybe_dt?; - Ok(Field::new(key.as_ref(), dt, true)) + Ok(Field::new(key.as_ref().into(), dt, true)) }) .collect::>>()?; Ok(ArrowDataType::Struct(fields)) @@ -45,13 +46,15 @@ fn infer_array(values: &[BorrowedValue]) -> PolarsResult { let dt = if !types.is_empty() { let types = types.into_iter().collect::>(); - coerce_data_type(&types) + coerce_dtype(&types) } else { ArrowDataType::Null }; Ok(ArrowDataType::LargeList(Box::new(Field::new( - ITEM_NAME, dt, true, + 
PlSmallStr::from_static(ITEM_NAME), + dt, + true, )))) } @@ -61,7 +64,7 @@ fn infer_array(values: &[BorrowedValue]) -> PolarsResult { /// * Lists and scalars are coerced to a list of a compatible scalar /// * Structs contain the union of all fields /// * All other types are coerced to `Utf8` -pub(crate) fn coerce_data_type>(datatypes: &[A]) -> ArrowDataType { +pub(crate) fn coerce_dtype>(datatypes: &[A]) -> ArrowDataType { use ArrowDataType::*; if datatypes.is_empty() { @@ -94,11 +97,11 @@ pub(crate) fn coerce_data_type>(datatypes: &[A]) -> Arr |mut acc, field| { match acc.entry(field.name.as_str()) { Entry::Occupied(mut v) => { - v.get_mut().insert(&field.data_type); + v.get_mut().insert(&field.dtype); }, Entry::Vacant(v) => { let mut a = PlHashSet::default(); - a.insert(&field.data_type); + a.insert(&field.dtype); v.insert(a); }, } @@ -110,7 +113,7 @@ pub(crate) fn coerce_data_type>(datatypes: &[A]) -> Arr .into_iter() .map(|(name, dts)| { let dts = dts.into_iter().collect::>(); - Field::new(name, coerce_data_type(&dts), true) + Field::new(name.into(), coerce_dtype(&dts), true) }) .collect(); return Struct(fields); @@ -119,41 +122,47 @@ pub(crate) fn coerce_data_type>(datatypes: &[A]) -> Arr .iter() .map(|dt| { if let LargeList(inner) = dt.borrow() { - inner.data_type() + inner.dtype() } else { unreachable!(); } }) .collect(); return LargeList(Box::new(Field::new( - ITEM_NAME, - coerce_data_type(inner_types.as_slice()), + PlSmallStr::from_static(ITEM_NAME), + coerce_dtype(inner_types.as_slice()), true, ))); } else if datatypes.len() > 2 { - return datatypes - .iter() - .map(|dt| dt.borrow().clone()) - .reduce(|a, b| coerce_data_type(&[a, b])) - .unwrap() - .borrow() - .clone(); + return coerce_dtype(datatypes); } let (lhs, rhs) = (datatypes[0].borrow(), datatypes[1].borrow()); return match (lhs, rhs) { (lhs, rhs) if lhs == rhs => lhs.clone(), (LargeList(lhs), LargeList(rhs)) => { - let inner = coerce_data_type(&[lhs.data_type(), rhs.data_type()]); - LargeList(Box::new(Field::new(ITEM_NAME, inner, true))) + let inner = coerce_dtype(&[lhs.dtype(), rhs.dtype()]); + LargeList(Box::new(Field::new( + PlSmallStr::from_static(ITEM_NAME), + inner, + true, + ))) }, (scalar, LargeList(list)) => { - let inner = coerce_data_type(&[scalar, list.data_type()]); - LargeList(Box::new(Field::new(ITEM_NAME, inner, true))) + let inner = coerce_dtype(&[scalar, list.dtype()]); + LargeList(Box::new(Field::new( + PlSmallStr::from_static(ITEM_NAME), + inner, + true, + ))) }, (LargeList(list), scalar) => { - let inner = coerce_data_type(&[scalar, list.data_type()]); - LargeList(Box::new(Field::new(ITEM_NAME, inner, true))) + let inner = coerce_dtype(&[scalar, list.dtype()]); + LargeList(Box::new(Field::new( + PlSmallStr::from_static(ITEM_NAME), + inner, + true, + ))) }, (Float64, Int64) => Float64, (Int64, Float64) => Float64, diff --git a/crates/polars-json/src/json/write/mod.rs b/crates/polars-json/src/json/write/mod.rs index 3a9bac40a7fe..a23b245b68b2 100644 --- a/crates/polars-json/src/json/write/mod.rs +++ b/crates/polars-json/src/json/write/mod.rs @@ -114,7 +114,7 @@ impl<'a> FallibleStreamingIterator for RecordSerializer<'a> { let mut is_first_row = true; write!(&mut self.buffer, "{{")?; - for (f, ref mut it) in self.schema.fields.iter().zip(self.iterators.iter_mut()) { + for (f, ref mut it) in self.schema.iter_values().zip(self.iterators.iter_mut()) { if !is_first_row { write!(&mut self.buffer, ",")?; } diff --git a/crates/polars-json/src/json/write/serialize.rs 
b/crates/polars-json/src/json/write/serialize.rs index 872e13970814..2fd5920bd2f2 100644 --- a/crates/polars-json/src/json/write/serialize.rs +++ b/crates/polars-json/src/json/write/serialize.rs @@ -406,7 +406,7 @@ pub(crate) fn new_serializer<'a>( offset: usize, take: usize, ) -> Box + 'a + Send + Sync> { - match array.data_type().to_logical_type() { + match array.dtype().to_logical_type() { ArrowDataType::Boolean => { boolean_serializer(array.as_any().downcast_ref().unwrap(), offset, take) }, diff --git a/crates/polars-json/src/ndjson/deserialize.rs b/crates/polars-json/src/ndjson/deserialize.rs index 35961e96c9a2..d8bab5af157b 100644 --- a/crates/polars-json/src/ndjson/deserialize.rs +++ b/crates/polars-json/src/ndjson/deserialize.rs @@ -15,7 +15,7 @@ use super::*; /// This function errors iff any of the rows is not a valid JSON (i.e. the format is not valid NDJSON). pub fn deserialize_iter<'a>( rows: impl Iterator, - data_type: ArrowDataType, + dtype: ArrowDataType, buf_size: usize, count: usize, ) -> PolarsResult { @@ -23,12 +23,12 @@ pub fn deserialize_iter<'a>( let mut buf = String::with_capacity(std::cmp::min(buf_size + count + 2, u32::MAX as usize)); buf.push('['); - fn _deserializer(s: &mut str, data_type: ArrowDataType) -> PolarsResult> { + fn _deserializer(s: &mut str, dtype: ArrowDataType) -> PolarsResult> { let slice = unsafe { s.as_bytes_mut() }; let out = simd_json::to_borrowed_value(slice) .map_err(|e| PolarsError::ComputeError(format!("json parsing error: '{e}'").into()))?; Ok(if let BorrowedValue::Array(rows) = out { - super::super::json::deserialize::_deserialize(&rows, data_type.clone()) + super::super::json::deserialize::_deserialize(&rows, dtype.clone()) } else { unreachable!() }) @@ -43,7 +43,7 @@ pub fn deserialize_iter<'a>( if buf.len() + next_row_length >= u32::MAX as usize { let _ = buf.pop(); buf.push(']'); - arr.push(_deserializer(&mut buf, data_type.clone())?); + arr.push(_deserializer(&mut buf, dtype.clone())?); buf.clear(); buf.push('['); } @@ -54,9 +54,9 @@ pub fn deserialize_iter<'a>( buf.push(']'); if arr.is_empty() { - _deserializer(&mut buf, data_type.clone()) + _deserializer(&mut buf, dtype.clone()) } else { - arr.push(_deserializer(&mut buf, data_type.clone())?); + arr.push(_deserializer(&mut buf, dtype.clone())?); concatenate_owned_unchecked(&arr) } } diff --git a/crates/polars-json/src/ndjson/file.rs b/crates/polars-json/src/ndjson/file.rs index 3bc2e126fb85..e0a166f934e8 100644 --- a/crates/polars-json/src/ndjson/file.rs +++ b/crates/polars-json/src/ndjson/file.rs @@ -5,7 +5,7 @@ use arrow::datatypes::ArrowDataType; use fallible_streaming_iterator::FallibleStreamingIterator; use indexmap::IndexSet; use polars_error::*; -use polars_utils::aliases::PlIndexSet; +use polars_utils::aliases::{PlIndexSet, PlRandomState}; use simd_json::BorrowedValue; /// Reads up to a number of lines from `reader` into `rows` bounded by `limit`. @@ -90,6 +90,7 @@ fn parse_value<'a>(scratch: &'a mut Vec, val: &[u8]) -> PolarsResult( let rows = vec!["".to_string(); 1]; // 1 <=> read row by row let mut reader = FileReader::new(reader, rows, number_of_rows.map(|v| v.into())); - let mut data_types = PlIndexSet::default(); + let mut dtypes = PlIndexSet::default(); let mut buf = vec![]; while let Some(rows) = reader.next()? 
{ // 0 because it is row by row let value = parse_value(&mut buf, rows[0].as_bytes())?; - let data_type = crate::json::infer(&value)?; - data_types.insert(data_type); + let dtype = crate::json::infer(&value)?; + dtypes.insert(dtype); } - Ok(data_types.into_iter()) + Ok(dtypes.into_iter()) } /// Infers the [`ArrowDataType`] from an iterator of JSON strings. A limited number of @@ -129,17 +130,17 @@ pub fn iter_unique_dtypes( /// # Implementation /// This implementation infers each row by going through the entire iterator. pub fn infer_iter>(rows: impl Iterator) -> PolarsResult { - let mut data_types = IndexSet::<_, ahash::RandomState>::default(); + let mut dtypes = IndexSet::<_, PlRandomState>::default(); let mut buf = vec![]; for row in rows { let v = parse_value(&mut buf, row.as_ref().as_bytes())?; - let data_type = crate::json::infer(&v)?; - if data_type != ArrowDataType::Null { - data_types.insert(data_type); + let dtype = crate::json::infer(&v)?; + if dtype != ArrowDataType::Null { + dtypes.insert(dtype); } } - let v: Vec<&ArrowDataType> = data_types.iter().collect(); - Ok(crate::json::infer_schema::coerce_data_type(&v)) + let v: Vec<&ArrowDataType> = dtypes.iter().collect(); + Ok(crate::json::infer_schema::coerce_dtype(&v)) } diff --git a/crates/polars-json/src/ndjson/mod.rs b/crates/polars-json/src/ndjson/mod.rs index cf98157976a8..4b345362e4ac 100644 --- a/crates/polars-json/src/ndjson/mod.rs +++ b/crates/polars-json/src/ndjson/mod.rs @@ -4,5 +4,4 @@ use polars_error::*; pub mod deserialize; mod file; pub mod write; - pub use file::{infer_iter, iter_unique_dtypes}; diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml index 91477915c381..03fcc0d8b2c8 100644 --- a/crates/polars-lazy/Cargo.toml +++ b/crates/polars-lazy/Cargo.toml @@ -29,7 +29,6 @@ memchr = { workspace = true } once_cell = { workspace = true } pyo3 = { workspace = true, optional = true } rayon = { workspace = true } -smartstring = { workspace = true } tokio = { workspace = true, optional = true } [dev-dependencies] @@ -217,7 +216,7 @@ arg_where = ["polars-plan/arg_where"] search_sorted = ["polars-plan/search_sorted"] merge_sorted = ["polars-plan/merge_sorted"] meta = ["polars-plan/meta"] -pivot = ["polars-core/rows", "polars-ops/pivot"] +pivot = ["polars-core/rows", "polars-ops/pivot", "polars-plan/pivot"] top_k = ["polars-plan/top_k"] semi_anti_join = ["polars-plan/semi_anti_join"] cse = ["polars-plan/cse", "polars-mem-engine/cse"] @@ -231,6 +230,7 @@ serde = [ "polars-time?/serde", "polars-io/serde", "polars-ops/serde", + "polars-utils/serde", ] fused = ["polars-plan/fused", "polars-ops/fused"] list_sets = ["polars-plan/list_sets", "polars-ops/list_sets"] diff --git a/crates/polars-lazy/src/dsl/eval.rs b/crates/polars-lazy/src/dsl/eval.rs index 3d6a52fe562a..574c2b336407 100644 --- a/crates/polars-lazy/src/dsl/eval.rs +++ b/crates/polars-lazy/src/dsl/eval.rs @@ -8,12 +8,12 @@ use crate::prelude::*; pub(crate) fn eval_field_to_dtype(f: &Field, expr: &Expr, list: bool) -> Field { // Dummy df to determine output dtype. 
let dtype = f - .data_type() + .dtype() .inner_dtype() .cloned() - .unwrap_or_else(|| f.data_type().clone()); + .unwrap_or_else(|| f.dtype().clone()); - let df = Series::new_empty("", &dtype).into_frame(); + let df = Series::new_empty(PlSmallStr::EMPTY, &dtype).into_frame(); #[cfg(feature = "python")] let out = { @@ -27,12 +27,12 @@ pub(crate) fn eval_field_to_dtype(f: &Field, expr: &Expr, list: bool) -> Field { Ok(out) => { let dtype = out.get_columns()[0].dtype(); if list { - Field::new(f.name(), DataType::List(Box::new(dtype.clone()))) + Field::new(f.name().clone(), DataType::List(Box::new(dtype.clone()))) } else { - Field::new(f.name(), dtype.clone()) + Field::new(f.name().clone(), dtype.clone()) } }, - Err(_) => Field::new(f.name(), DataType::Null), + Err(_) => Field::new(f.name().clone(), DataType::Null), } } @@ -46,15 +46,15 @@ pub trait ExprEvalExtension: IntoExpr + Sized { let this = self.into_expr(); let expr2 = expr.clone(); let func = move |mut s: Series| { - let name = s.name().to_string(); - s.rename(""); + let name = s.name().clone(); + s.rename(PlSmallStr::EMPTY); // Ensure we get the new schema. let output_field = eval_field_to_dtype(s.field().as_ref(), &expr, false); let expr = expr.clone(); let mut arena = Arena::with_capacity(10); - let aexpr = to_expr_ir(expr, &mut arena); + let aexpr = to_expr_ir(expr, &mut arena)?; let phys_expr = create_physical_expr( &aexpr, Context::Default, @@ -107,10 +107,10 @@ pub trait ExprEvalExtension: IntoExpr + Sized { }) .collect::>>()? }; - let s = Series::new(&name, avs); + let s = Series::new(name, avs); - if s.dtype() != output_field.data_type() { - s.cast(output_field.data_type()).map(Some) + if s.dtype() != output_field.dtype() { + s.cast(output_field.dtype()).map(Some) } else { Ok(Some(s)) } diff --git a/crates/polars-lazy/src/dsl/functions.rs b/crates/polars-lazy/src/dsl/functions.rs index 3af48be7a81b..0cb320eec081 100644 --- a/crates/polars-lazy/src/dsl/functions.rs +++ b/crates/polars-lazy/src/dsl/functions.rs @@ -30,8 +30,8 @@ pub(crate) fn concat_impl>( for lf in &mut inputs[1..] { // Ensure we enable file caching if any lf has it enabled. - if lf.opt_state.contains(OptState::FILE_CACHING) { - opt_state |= OptState::FILE_CACHING; + if lf.opt_state.contains(OptFlags::FILE_CACHING) { + opt_state |= OptFlags::FILE_CACHING; } let lp = std::mem::take(&mut lf.logical_plan); lps.push(lp) @@ -67,8 +67,8 @@ pub fn concat_lf_horizontal>( for lf in &lfs[1..] { // Ensure we enable file caching if any lf has it enabled. 
- if lf.opt_state.contains(OptState::FILE_CACHING) { - opt_state |= OptState::FILE_CACHING; + if lf.opt_state.contains(OptFlags::FILE_CACHING) { + opt_state |= OptFlags::FILE_CACHING; } } diff --git a/crates/polars-lazy/src/dsl/list.rs b/crates/polars-lazy/src/dsl/list.rs index 34df33c10c50..fb1594196e41 100644 --- a/crates/polars-lazy/src/dsl/list.rs +++ b/crates/polars-lazy/src/dsl/list.rs @@ -50,7 +50,12 @@ fn run_per_sublist( parallel: bool, output_field: Field, ) -> PolarsResult> { - let phys_expr = prepare_expression_for_context("", expr, lst.inner_dtype(), Context::Default)?; + let phys_expr = prepare_expression_for_context( + PlSmallStr::EMPTY, + expr, + lst.inner_dtype(), + Context::Default, + )?; let state = ExecutionState::new(); @@ -72,7 +77,7 @@ fn run_per_sublist( } }) }) - .collect_ca_with_dtype("", output_field.dtype.clone()); + .collect_ca_with_dtype(PlSmallStr::EMPTY, output_field.dtype.clone()); err = m_err.into_inner().unwrap(); ca } else { @@ -99,17 +104,17 @@ fn run_per_sublist( return Err(err); } - ca.rename(s.name()); + ca.rename(s.name().clone()); - if ca.dtype() != output_field.data_type() { - ca.cast(output_field.data_type()).map(Some) + if ca.dtype() != output_field.dtype() { + ca.cast(output_field.dtype()).map(Some) } else { Ok(Some(ca.into_series())) } } fn run_on_group_by_engine( - name: &str, + name: PlSmallStr, lst: &ListChunked, expr: &Expr, ) -> PolarsResult> { @@ -118,19 +123,20 @@ fn run_on_group_by_engine( let groups = offsets_to_groups(arr.offsets()).unwrap(); // List elements in a series. - let values = Series::try_from(("", arr.values().clone())).unwrap(); + let values = Series::try_from((PlSmallStr::EMPTY, arr.values().clone())).unwrap(); let inner_dtype = lst.inner_dtype(); // SAFETY: // Invariant in List means values physicals can be cast to inner dtype let values = unsafe { values.cast_unchecked(inner_dtype).unwrap() }; let df_context = values.into_frame(); - let phys_expr = prepare_expression_for_context("", expr, inner_dtype, Context::Aggregation)?; + let phys_expr = + prepare_expression_for_context(PlSmallStr::EMPTY, expr, inner_dtype, Context::Aggregation)?; let state = ExecutionState::new(); let mut ac = phys_expr.evaluate_on_groups(&df_context, &groups, &state)?; let out = match ac.agg_state() { - AggState::AggregatedScalar(_) | AggState::Literal(_) => { + AggState::AggregatedScalar(_) => { let out = ac.aggregated(); out.as_list().into_series() }, @@ -150,7 +156,7 @@ pub trait ListNameSpaceExtension: IntoListNameSpace + Sized { match e { #[cfg(feature = "dtype-categorical")] Expr::Cast { - data_type: DataType::Categorical(_, _) | DataType::Enum(_, _), + dtype: DataType::Categorical(_, _) | DataType::Enum(_, _), .. 
} => { polars_bail!( @@ -173,10 +179,13 @@ pub trait ListNameSpaceExtension: IntoListNameSpace + Sized { // ensure we get the new schema let output_field = eval_field_to_dtype(lst.ref_field(), &expr, true); if lst.is_empty() { - return Ok(Some(Series::new_empty(s.name(), output_field.data_type()))); + return Ok(Some(Series::new_empty( + s.name().clone(), + output_field.dtype(), + ))); } if lst.null_count() == lst.len() { - return Ok(Some(s.cast(output_field.data_type())?)); + return Ok(Some(s.cast(output_field.dtype())?)); } let fits_idx_size = lst.get_values_size() <= (IdxSize::MAX as usize); @@ -187,7 +196,7 @@ pub trait ListNameSpaceExtension: IntoListNameSpace + Sized { }; if fits_idx_size && s.null_count() == 0 && !is_user_apply() { - run_on_group_by_engine(s.name(), &lst, &expr) + run_on_group_by_engine(s.name().clone(), &lst, &expr) } else { run_per_sublist(s, &lst, &expr, parallel, output_field) } diff --git a/crates/polars-lazy/src/frame/cached_arenas.rs b/crates/polars-lazy/src/frame/cached_arenas.rs index ecca97c06ac5..3f985b8b1298 100644 --- a/crates/polars-lazy/src/frame/cached_arenas.rs +++ b/crates/polars-lazy/src/frame/cached_arenas.rs @@ -19,7 +19,12 @@ impl LazyFrame { lp_arena: &mut Arena, expr_arena: &mut Arena, ) -> PolarsResult { - let node = to_alp(self.logical_plan.clone(), expr_arena, lp_arena, false, true)?; + let node = to_alp( + self.logical_plan.clone(), + expr_arena, + lp_arena, + &mut OptFlags::schema_only(), + )?; let schema = lp_arena.get(node).schema(lp_arena).into_owned(); // Cache the logical plan so that next schema call is cheap. @@ -36,7 +41,7 @@ impl LazyFrame { /// /// Returns an `Err` if the logical plan has already encountered an error (i.e., if /// `self.collect()` would fail), `Ok` otherwise. - pub fn schema(&mut self) -> PolarsResult { + pub fn collect_schema(&mut self) -> PolarsResult { let mut cached_arenas = self.cached_arena.lock().unwrap(); match &mut *cached_arenas { @@ -48,8 +53,7 @@ impl LazyFrame { self.logical_plan.clone(), &mut expr_arena, &mut lp_arena, - false, - true, + &mut OptFlags::schema_only(), )?; let schema = lp_arena.get(node).schema(&lp_arena).into_owned(); @@ -83,8 +87,7 @@ impl LazyFrame { self.logical_plan.clone(), &mut arenas.expr_arena, &mut arenas.lp_arena, - false, - true, + &mut OptFlags::schema_only(), )?; let schema = arenas diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index faa5e004c551..b2cfd7267025 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -35,9 +35,9 @@ use polars_expr::{create_physical_expr, ExpressionConversionState}; use polars_io::RowIndex; use polars_mem_engine::{create_physical_plan, Executor}; use polars_ops::frame::JoinCoalesce; -pub use polars_plan::frame::{AllowedOptimizations, OptState}; +pub use polars_plan::frame::{AllowedOptimizations, OptFlags}; use polars_plan::global::FETCH_ROWS; -use smartstring::alias::String as SmartString; +use polars_utils::pl_str::PlSmallStr; use crate::frame::cached_arenas::CachedArena; #[cfg(feature = "streaming")] @@ -67,13 +67,14 @@ impl IntoLazy for LazyFrame { } /// Lazy abstraction over an eager `DataFrame`. +/// /// It really is an abstraction over a logical plan. The methods of this struct will incrementally /// modify a logical plan until output is requested (via [`collect`](crate::frame::LazyFrame::collect)). 
#[derive(Clone, Default)] #[must_use] pub struct LazyFrame { pub logical_plan: DslPlan, - pub(crate) opt_state: OptState, + pub(crate) opt_state: OptFlags, pub(crate) cached_arena: Arc>>, } @@ -81,7 +82,7 @@ impl From for LazyFrame { fn from(plan: DslPlan) -> Self { Self { logical_plan: plan, - opt_state: OptState::default() | OptState::FILE_CACHING, + opt_state: OptFlags::default() | OptFlags::FILE_CACHING, cached_arena: Default::default(), } } @@ -90,7 +91,7 @@ impl From for LazyFrame { impl LazyFrame { pub(crate) fn from_inner( logical_plan: DslPlan, - opt_state: OptState, + opt_state: OptFlags, cached_arena: Arc>>, ) -> Self { Self { @@ -104,11 +105,11 @@ impl LazyFrame { DslBuilder::from(self.logical_plan) } - fn get_opt_state(&self) -> OptState { + fn get_opt_state(&self) -> OptFlags { self.opt_state } - fn from_logical_plan(logical_plan: DslPlan, opt_state: OptState) -> Self { + fn from_logical_plan(logical_plan: DslPlan, opt_state: OptFlags) -> Self { LazyFrame { logical_plan, opt_state, @@ -117,91 +118,93 @@ impl LazyFrame { } /// Get current optimizations. - pub fn get_current_optimizations(&self) -> OptState { + pub fn get_current_optimizations(&self) -> OptFlags { self.opt_state } /// Set allowed optimizations. - pub fn with_optimizations(mut self, opt_state: OptState) -> Self { + pub fn with_optimizations(mut self, opt_state: OptFlags) -> Self { self.opt_state = opt_state; self } /// Turn off all optimizations. pub fn without_optimizations(self) -> Self { - self.with_optimizations(OptState::from_bits_truncate(0) | OptState::TYPE_COERCION) + self.with_optimizations(OptFlags::from_bits_truncate(0) | OptFlags::TYPE_COERCION) } /// Toggle projection pushdown optimization. pub fn with_projection_pushdown(mut self, toggle: bool) -> Self { - self.opt_state.set(OptState::PROJECTION_PUSHDOWN, toggle); + self.opt_state.set(OptFlags::PROJECTION_PUSHDOWN, toggle); self } /// Toggle cluster with columns optimization. pub fn with_cluster_with_columns(mut self, toggle: bool) -> Self { - self.opt_state.set(OptState::CLUSTER_WITH_COLUMNS, toggle); + self.opt_state.set(OptFlags::CLUSTER_WITH_COLUMNS, toggle); self } /// Toggle predicate pushdown optimization. pub fn with_predicate_pushdown(mut self, toggle: bool) -> Self { - self.opt_state.set(OptState::PREDICATE_PUSHDOWN, toggle); + self.opt_state.set(OptFlags::PREDICATE_PUSHDOWN, toggle); self } /// Toggle type coercion optimization. pub fn with_type_coercion(mut self, toggle: bool) -> Self { - self.opt_state.set(OptState::TYPE_COERCION, toggle); + self.opt_state.set(OptFlags::TYPE_COERCION, toggle); self } /// Toggle expression simplification optimization on or off. pub fn with_simplify_expr(mut self, toggle: bool) -> Self { - self.opt_state.set(OptState::SIMPLIFY_EXPR, toggle); + self.opt_state.set(OptFlags::SIMPLIFY_EXPR, toggle); self } /// Toggle common subplan elimination optimization on or off #[cfg(feature = "cse")] pub fn with_comm_subplan_elim(mut self, toggle: bool) -> Self { - self.opt_state.set(OptState::COMM_SUBPLAN_ELIM, toggle); + self.opt_state.set(OptFlags::COMM_SUBPLAN_ELIM, toggle); self } /// Toggle common subexpression elimination optimization on or off #[cfg(feature = "cse")] pub fn with_comm_subexpr_elim(mut self, toggle: bool) -> Self { - self.opt_state.set(OptState::COMM_SUBEXPR_ELIM, toggle); + self.opt_state.set(OptFlags::COMM_SUBEXPR_ELIM, toggle); self } /// Toggle slice pushdown optimization. 
pub fn with_slice_pushdown(mut self, toggle: bool) -> Self { - self.opt_state.set(OptState::SLICE_PUSHDOWN, toggle); + self.opt_state.set(OptFlags::SLICE_PUSHDOWN, toggle); self } /// Run nodes that are capably of doing so on the streaming engine. + #[cfg(feature = "streaming")] pub fn with_streaming(mut self, toggle: bool) -> Self { - self.opt_state.set(OptState::STREAMING, toggle); + self.opt_state.set(OptFlags::STREAMING, toggle); self } + #[cfg(feature = "new_streaming")] pub fn with_new_streaming(mut self, toggle: bool) -> Self { - self.opt_state.set(OptState::NEW_STREAMING, toggle); + self.opt_state.set(OptFlags::NEW_STREAMING, toggle); self } /// Try to estimate the number of rows so that joins can determine which side to keep in memory. pub fn with_row_estimate(mut self, toggle: bool) -> Self { - self.opt_state.set(OptState::ROW_ESTIMATE, toggle); + self.opt_state.set(OptFlags::ROW_ESTIMATE, toggle); self } /// Run every node eagerly. This turns off multi-node optimizations. pub fn _with_eager(mut self, toggle: bool) -> Self { - self.opt_state.set(OptState::EAGER, toggle); + self.opt_state.set(OptFlags::EAGER, toggle); self } @@ -285,21 +288,18 @@ impl LazyFrame { /// # use polars_lazy::prelude::*; /// fn sort_by_multiple_columns_with_specific_order(df: DataFrame) -> LazyFrame { /// df.lazy().sort( - /// &["sepal_width", "sepal_length"], + /// ["sepal_width", "sepal_length"], /// SortMultipleOptions::new() /// .with_order_descending_multi([false, true]) /// ) /// } /// ``` /// See [`SortMultipleOptions`] for more options. - pub fn sort(self, by: impl IntoVec, sort_options: SortMultipleOptions) -> Self { + pub fn sort(self, by: impl IntoVec, sort_options: SortMultipleOptions) -> Self { let opt_state = self.get_opt_state(); let lp = self .get_plan_builder() - .sort( - by.into_vec().into_iter().map(|x| col(&x)).collect(), - sort_options, - ) + .sort(by.into_vec().into_iter().map(col).collect(), sort_options) .build(); Self::from_logical_plan(lp, opt_state) } @@ -379,7 +379,7 @@ impl LazyFrame { /// } /// ``` pub fn reverse(self) -> Self { - self.select(vec![col("*").reverse()]) + self.select(vec![col(PlSmallStr::from_static("*")).reverse()]) } /// Rename columns in the DataFrame. @@ -397,8 +397,8 @@ impl LazyFrame { { let iter = existing.into_iter(); let cap = iter.size_hint().0; - let mut existing_vec: Vec = Vec::with_capacity(cap); - let mut new_vec: Vec = Vec::with_capacity(cap); + let mut existing_vec: Vec = Vec::with_capacity(cap); + let mut new_vec: Vec = Vec::with_capacity(cap); // TODO! should this error if `existing` and `new` have different lengths? // Currently, the longer of the two is truncated. @@ -467,7 +467,7 @@ impl LazyFrame { /// /// See the method on [Series](polars_core::series::SeriesTrait::shift) for more info on the `shift` operation. pub fn shift>(self, n: E) -> Self { - self.select(vec![col("*").shift(n.into())]) + self.select(vec![col(PlSmallStr::from_static("*")).shift(n.into())]) } /// Shift the values by a given period and fill the parts that will be empty due to this operation @@ -475,7 +475,9 @@ impl LazyFrame { /// /// See the method on [Series](polars_core::series::SeriesTrait::shift) for more info on the `shift` operation. pub fn shift_and_fill, IE: Into>(self, n: E, fill_value: IE) -> Self { - self.select(vec![col("*").shift_and_fill(n.into(), fill_value.into())]) + self.select(vec![ + col(PlSmallStr::from_static("*")).shift_and_fill(n.into(), fill_value.into()) + ]) } /// Fill None values in the DataFrame with an expression. 
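// --- Editor's illustrative sketch (not part of the patch) ---
// The hunks above rename `OptState` to `OptFlags`; each `with_*` builder method on
// `LazyFrame` sets or clears a single flag. A usage sketch in the style of this
// file's own doc examples, assuming the polars-core/polars-lazy preludes;
// `configure` is a hypothetical helper name, not an API introduced by the patch:
use polars_core::prelude::*;
use polars_lazy::prelude::*;

fn configure(df: DataFrame) -> LazyFrame {
    df.lazy()
        // Each call toggles one bit in the frame's internal `OptFlags`.
        .with_projection_pushdown(true)
        .with_predicate_pushdown(true)
        .with_slice_pushdown(false)
        .with_type_coercion(true)
}
// --- end sketch ---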
@@ -506,6 +508,8 @@ impl LazyFrame { let cast_cols: Vec = dtypes .into_iter() .map(|(name, dt)| { + let name = PlSmallStr::from_str(name); + if strict { col(name).strict_cast(dt) } else { @@ -524,9 +528,9 @@ impl LazyFrame { /// Cast all frame columns to the given dtype, resulting in a new LazyFrame pub fn cast_all(self, dtype: DataType, strict: bool) -> Self { self.with_columns(vec![if strict { - col("*").strict_cast(dtype) + col(PlSmallStr::from_static("*")).strict_cast(dtype) } else { - col("*").cast(dtype) + col(PlSmallStr::from_static("*")).cast(dtype) }]) } @@ -565,8 +569,7 @@ impl LazyFrame { self.logical_plan, &mut expr_arena, &mut lp_arena, - true, - true, + &mut self.opt_state, )?; let plan = IRPlan::new(node, lp_arena, expr_arena); Ok(plan) @@ -581,11 +584,20 @@ impl LazyFrame { ) -> PolarsResult { #[allow(unused_mut)] let mut opt_state = self.opt_state; - let streaming = self.opt_state.contains(OptState::STREAMING); + let streaming = self.opt_state.contains(OptFlags::STREAMING); + let new_streaming = self.opt_state.contains(OptFlags::NEW_STREAMING); + #[cfg(feature = "cse")] + if streaming && !new_streaming { + opt_state &= !OptFlags::COMM_SUBPLAN_ELIM; + } + + // The new streaming engine can't deal with the way the common + // subexpression elimination adds length-incorrect with_columns. #[cfg(feature = "cse")] - if streaming && self.opt_state.contains(OptState::COMM_SUBPLAN_ELIM) { - opt_state &= !OptState::COMM_SUBPLAN_ELIM; + if new_streaming { + opt_state &= !OptFlags::COMM_SUBEXPR_ELIM; } + let lp_top = optimize( self.logical_plan, opt_state, @@ -616,7 +628,7 @@ impl LazyFrame { scratch, enable_fmt, true, - opt_state.contains(OptState::ROW_ESTIMATE), + opt_state.contains(OptFlags::ROW_ESTIMATE), )?; } #[cfg(not(feature = "streaming"))] @@ -694,48 +706,46 @@ impl LazyFrame { pub fn collect(self) -> PolarsResult { #[cfg(feature = "new_streaming")] { - let force_new_streaming = self.opt_state.contains(OptState::NEW_STREAMING); - let mut alp_plan = self.to_alp_optimized()?; - let stream_lp_top = alp_plan.lp_arena.add(IR::Sink { - input: alp_plan.lp_top, - payload: SinkType::Memory, - }); - - if force_new_streaming { - return polars_stream::run_query( - stream_lp_top, - alp_plan.lp_arena, - &alp_plan.expr_arena, - ); - } + let auto_new_streaming = + std::env::var("POLARS_AUTO_NEW_STREAMING").as_deref() == Ok("1"); + if self.opt_state.contains(OptFlags::NEW_STREAMING) || auto_new_streaming { + // Try to run using the new streaming engine, falling back + // if it fails in a todo!() error if auto_new_streaming is set. + let mut new_stream_lazy = self.clone(); + new_stream_lazy.opt_state |= OptFlags::NEW_STREAMING; + new_stream_lazy.opt_state &= !OptFlags::STREAMING; + let mut alp_plan = new_stream_lazy.to_alp_optimized()?; + let stream_lp_top = alp_plan.lp_arena.add(IR::Sink { + input: alp_plan.lp_top, + payload: SinkType::Memory, + }); - if std::env::var("POLARS_AUTO_NEW_STREAMING") - .as_deref() - .unwrap_or("") - == "1" - { let f = || { polars_stream::run_query( stream_lp_top, - alp_plan.lp_arena.clone(), - &alp_plan.expr_arena, + alp_plan.lp_arena, + &mut alp_plan.expr_arena, ) }; match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) { Ok(r) => return r, Err(e) => { - // Fallback to normal engine if error is due to not being implemented, - // otherwise propagate error. - if e.downcast_ref::<&str>() != Some(&"not yet implemented") { + // Fallback to normal engine if error is due to not being implemented + // and auto_new_streaming is set, otherwise propagate error. 
+ if auto_new_streaming + && e.downcast_ref::<&str>() == Some(&"not yet implemented") + { if polars_core::config::verbose() { eprintln!("caught unimplemented error in new streaming engine, falling back to normal engine"); } + } else { std::panic::resume_unwind(e); } }, } } + let mut alp_plan = self.to_alp_optimized()?; let mut physical_plan = create_physical_plan( alp_plan.lp_top, &mut alp_plan.lp_arena, @@ -828,7 +838,7 @@ impl LazyFrame { cloud_options: Option, ipc_options: IpcWriterOptions, ) -> PolarsResult<()> { - self.opt_state |= OptState::STREAMING; + self.opt_state |= OptFlags::STREAMING; self.logical_plan = DslPlan::Sink { input: Arc::new(self.logical_plan), payload: SinkType::Cloud { @@ -883,7 +893,7 @@ impl LazyFrame { feature = "json", ))] fn sink(mut self, payload: SinkType, msg_alternative: &str) -> Result<(), PolarsError> { - self.opt_state |= OptState::STREAMING; + self.opt_state |= OptFlags::STREAMING; self.logical_plan = DslPlan::Sink { input: Arc::new(self.logical_plan), payload, @@ -911,7 +921,7 @@ impl LazyFrame { /// fn example(df: DataFrame) -> LazyFrame { /// df.lazy() /// .filter(col("sepal_width").is_not_null()) - /// .select(&[col("sepal_width"), col("sepal_length")]) + /// .select([col("sepal_width"), col("sepal_length")]) /// } /// ``` pub fn filter(self, predicate: Expr) -> Self { @@ -923,7 +933,7 @@ impl LazyFrame { /// Select (and optionally rename, with [`alias`](crate::dsl::Expr::alias)) columns from the query. /// /// Columns can be selected with [`col`]; - /// If you want to select all columns use `col("*")`. + /// If you want to select all columns use `col(PlSmallStr::from_static("*"))`. /// /// # Example /// @@ -935,14 +945,14 @@ impl LazyFrame { /// /// Column "bar" is renamed to "ham". /// fn example(df: DataFrame) -> LazyFrame { /// df.lazy() - /// .select(&[col("foo"), + /// .select([col("foo"), /// col("bar").alias("ham")]) /// } /// /// /// This function selects all columns except "foo" /// fn exclude_a_column(df: DataFrame) -> LazyFrame { /// df.lazy() - /// .select(&[col("*").exclude(["foo"])]) + /// .select([col(PlSmallStr::from_static("*")).exclude(["foo"])]) /// } /// ``` pub fn select>(self, exprs: E) -> Self { @@ -1042,13 +1052,13 @@ impl LazyFrame { mut options: RollingGroupOptions, ) -> LazyGroupBy { if let Expr::Column(name) = index_column { - options.index_column = name.as_ref().into(); + options.index_column = name; } else { let output_field = index_column - .to_field(&self.schema().unwrap(), Context::Default) + .to_field(&self.collect_schema().unwrap(), Context::Default) .unwrap(); return self.with_column(index_column).rolling( - Expr::Column(Arc::from(output_field.name().as_str())), + Expr::Column(output_field.name().clone()), group_by, options, ); @@ -1087,13 +1097,13 @@ impl LazyFrame { mut options: DynamicGroupOptions, ) -> LazyGroupBy { if let Expr::Column(name) = index_column { - options.index_column = name.as_ref().into(); + options.index_column = name; } else { let output_field = index_column - .to_field(&self.schema().unwrap(), Context::Default) + .to_field(&self.collect_schema().unwrap(), Context::Default) .unwrap(); return self.with_column(index_column).group_by_dynamic( - Expr::Column(Arc::from(output_field.name().as_str())), + Expr::Column(output_field.name().clone()), group_by, options, ); @@ -1169,7 +1179,7 @@ impl LazyFrame { /// Creates the Cartesian product from both frames, preserving the order of the left keys. 
#[cfg(feature = "cross_join")] - pub fn cross_join(self, other: LazyFrame, suffix: Option) -> LazyFrame { + pub fn cross_join(self, other: LazyFrame, suffix: Option) -> LazyFrame { self.join( other, vec![], @@ -1308,8 +1318,8 @@ impl LazyFrame { args: JoinArgs, ) -> LazyFrame { // if any of the nodes reads from files we must activate this this plan as well. - if other.opt_state.contains(OptState::FILE_CACHING) { - self.opt_state |= OptState::FILE_CACHING; + if other.opt_state.contains(OptFlags::FILE_CACHING) { + self.opt_state |= OptFlags::FILE_CACHING; } let left_on = left_on.as_ref().to_vec(); @@ -1512,20 +1522,32 @@ impl LazyFrame { } /// Apply explode operation. [See eager explode](polars_core::frame::DataFrame::explode). - pub fn explode, IE: Into + Clone>(self, columns: E) -> LazyFrame { + pub fn explode, IE: Into + Clone>(self, columns: E) -> LazyFrame { + self.explode_impl(columns, false) + } + + /// Apply explode operation. [See eager explode](polars_core::frame::DataFrame::explode). + fn explode_impl, IE: Into + Clone>( + self, + columns: E, + allow_empty: bool, + ) -> LazyFrame { let columns = columns .as_ref() .iter() .map(|e| e.clone().into()) .collect::>(); let opt_state = self.get_opt_state(); - let lp = self.get_plan_builder().explode(columns).build(); + let lp = self + .get_plan_builder() + .explode(columns, allow_empty) + .build(); Self::from_logical_plan(lp, opt_state) } /// Aggregate all the columns as the sum of their null value count. pub fn null_count(self) -> LazyFrame { - self.select(vec![col("*").null_count()]) + self.select(vec![col(PlSmallStr::from_static("*")).null_count()]) } /// Drop non-unique rows and maintain the order of kept rows. @@ -1534,15 +1556,33 @@ impl LazyFrame { /// `None`, all columns are considered. pub fn unique_stable( self, - subset: Option>, + subset: Option>, keep_strategy: UniqueKeepStrategy, ) -> LazyFrame { + self.unique_stable_generic(subset, keep_strategy) + } + + pub fn unique_stable_generic( + self, + subset: Option, + keep_strategy: UniqueKeepStrategy, + ) -> LazyFrame + where + E: AsRef<[IE]>, + IE: Into + Clone, + { + let subset = subset.map(|s| { + s.as_ref() + .iter() + .map(|e| e.clone().into()) + .collect::>() + }); + let opt_state = self.get_opt_state(); - let options = DistinctOptions { - subset: subset.map(Arc::new), + let options = DistinctOptionsDSL { + subset, maintain_order: true, keep_strategy, - ..Default::default() }; let lp = self.get_plan_builder().distinct(options).build(); Self::from_logical_plan(lp, opt_state) @@ -1560,12 +1600,25 @@ impl LazyFrame { subset: Option>, keep_strategy: UniqueKeepStrategy, ) -> LazyFrame { + self.unique_generic(subset, keep_strategy) + } + + pub fn unique_generic, IE: Into + Clone>( + self, + subset: Option, + keep_strategy: UniqueKeepStrategy, + ) -> LazyFrame { + let subset = subset.map(|s| { + s.as_ref() + .iter() + .map(|e| e.clone().into()) + .collect::>() + }); let opt_state = self.get_opt_state(); - let options = DistinctOptions { - subset: subset.map(Arc::new), + let options = DistinctOptionsDSL { + subset, maintain_order: false, keep_strategy, - ..Default::default() }; let lp = self.get_plan_builder().distinct(options).build(); Self::from_logical_plan(lp, opt_state) @@ -1620,8 +1673,9 @@ impl LazyFrame { /// Unpivot the DataFrame from wide to long format. /// - /// See [`UnpivotArgs`] for information on how to unpivot a DataFrame. - pub fn unpivot(self, args: UnpivotArgs) -> LazyFrame { + /// See [`UnpivotArgsIR`] for information on how to unpivot a DataFrame. 
+ #[cfg(feature = "pivot")] + pub fn unpivot(self, args: UnpivotArgsDSL) -> LazyFrame { let opt_state = self.get_opt_state(); let lp = self.get_plan_builder().unpivot(args).build(); Self::from_logical_plan(lp, opt_state) @@ -1664,7 +1718,7 @@ impl LazyFrame { function, optimizations, schema, - name.unwrap_or("ANONYMOUS UDF"), + PlSmallStr::from_static(name.unwrap_or("ANONYMOUS UDF")), ) .build(); Self::from_logical_plan(lp, opt_state) @@ -1700,15 +1754,20 @@ impl LazyFrame { /// # Warning /// This can have a negative effect on query performance. This may for instance block /// predicate pushdown optimization. - pub fn with_row_index(mut self, name: &str, offset: Option) -> LazyFrame { + pub fn with_row_index(mut self, name: S, offset: Option) -> LazyFrame + where + S: Into, + { + let name = name.into(); let add_row_index_in_map = match &mut self.logical_plan { DslPlan::Scan { file_options: options, scan_type, .. } if !matches!(scan_type, FileScan::Anonymous { .. }) => { + let name = name.clone(); options.row_index = Some(RowIndex { - name: Arc::from(name), + name, offset: offset.unwrap_or(0), }); false @@ -1717,10 +1776,7 @@ impl LazyFrame { }; if add_row_index_in_map { - self.map_private(DslFunction::RowIndex { - name: Arc::from(name), - offset, - }) + self.map_private(DslFunction::RowIndex { name, offset }) } else { self } @@ -1728,25 +1784,36 @@ impl LazyFrame { /// Return the number of non-null elements for each column. pub fn count(self) -> LazyFrame { - self.select(vec![col("*").count()]) + self.select(vec![col(PlSmallStr::from_static("*")).count()]) } /// Unnest the given `Struct` columns: the fields of the `Struct` type will be /// inserted as columns. #[cfg(feature = "dtype-struct")] - pub fn unnest, S: AsRef>(self, cols: I) -> Self { - self.map_private(DslFunction::FunctionNode(FunctionNode::Unnest { - columns: cols.into_iter().map(|s| Arc::from(s.as_ref())).collect(), - })) + pub fn unnest(self, cols: E) -> Self + where + E: AsRef<[IE]>, + IE: Into + Clone, + { + let cols = cols + .as_ref() + .iter() + .map(|ie| ie.clone().into()) + .collect::>(); + self.map_private(DslFunction::Unnest(cols)) } #[cfg(feature = "merge_sorted")] - pub fn merge_sorted(self, other: LazyFrame, key: &str) -> PolarsResult { + pub fn merge_sorted(self, other: LazyFrame, key: S) -> PolarsResult + where + S: Into, + { // The two DataFrames are temporary concatenated // this indicates until which chunk the data is from the left df // this trick allows us to reuse the `Union` architecture to get map over // two DataFrames - let left = self.map_private(DslFunction::FunctionNode(FunctionNode::Rechunk)); + let key = key.into(); + let left = self.map_private(DslFunction::FunctionIR(FunctionIR::Rechunk)); let q = concat( &[left, other], UnionArgs { @@ -1756,8 +1823,8 @@ impl LazyFrame { }, )?; Ok( - q.map_private(DslFunction::FunctionNode(FunctionNode::MergeSorted { - column: Arc::from(key), + q.map_private(DslFunction::FunctionIR(FunctionIR::MergeSorted { + column: key, })), ) } @@ -1767,7 +1834,7 @@ impl LazyFrame { #[derive(Clone)] pub struct LazyGroupBy { pub logical_plan: DslPlan, - opt_state: OptState, + opt_state: OptFlags, keys: Vec, maintain_order: bool, #[cfg(feature = "dynamic_group_by")] @@ -1790,7 +1857,7 @@ impl LazyGroupBy { /// Group by and aggregate. /// /// Select a column with [col] and choose an aggregation. - /// If you want to aggregate all columns use `col("*")`. + /// If you want to aggregate all columns use `col(PlSmallStr::from_static("*"))`. 
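// A minimal sketch of the `with_row_index` signature above; the frame and the
// index-column name are hypothetical.
fn add_row_index(lf: LazyFrame) -> LazyFrame {
    // Any `S: Into<PlSmallStr>` works, e.g. a plain string literal.
    lf.with_row_index("row_nr", None)
}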
/// /// # Example /// @@ -1837,8 +1904,13 @@ impl LazyGroupBy { .filter_map(|expr| expr_output_name(expr).ok()) .collect::>(); - self.agg([col("*").exclude(&keys).head(n)]) - .explode([col("*").exclude(&keys)]) + self.agg([col(PlSmallStr::from_static("*")) + .exclude(keys.iter().cloned()) + .head(n)]) + .explode_impl( + [col(PlSmallStr::from_static("*")).exclude(keys.iter().cloned())], + true, + ) } /// Return last n rows of each group @@ -1849,8 +1921,13 @@ impl LazyGroupBy { .filter_map(|expr| expr_output_name(expr).ok()) .collect::>(); - self.agg([col("*").exclude(&keys).tail(n)]) - .explode([col("*").exclude(&keys)]) + self.agg([col(PlSmallStr::from_static("*")) + .exclude(keys.iter().cloned()) + .tail(n)]) + .explode_impl( + [col(PlSmallStr::from_static("*")).exclude(keys.iter().cloned())], + true, + ) } /// Apply a function over the groups as a new DataFrame. @@ -1892,7 +1969,7 @@ pub struct JoinBuilder { right_on: Vec, allow_parallel: bool, force_parallel: bool, - suffix: Option, + suffix: Option, validation: JoinValidation, coalesce: JoinCoalesce, join_nulls: bool, @@ -1978,8 +2055,11 @@ impl JoinBuilder { /// Suffix to add duplicate column names in join. /// Defaults to `"_right"` if this method is never called. - pub fn suffix>(mut self, suffix: S) -> Self { - self.suffix = Some(suffix.as_ref().to_string()); + pub fn suffix(mut self, suffix: S) -> Self + where + S: Into, + { + self.suffix = Some(suffix.into()); self } @@ -1994,9 +2074,9 @@ impl JoinBuilder { let mut opt_state = self.lf.opt_state; let other = self.other.expect("with not set"); - // If any of the nodes reads from files we must activate this this plan as well. - if other.opt_state.contains(OptState::FILE_CACHING) { - opt_state |= OptState::FILE_CACHING; + // If any of the nodes reads from files we must activate this plan as well. + if other.opt_state.contains(OptFlags::FILE_CACHING) { + opt_state |= OptFlags::FILE_CACHING; } let args = JoinArgs { @@ -2026,4 +2106,41 @@ impl JoinBuilder { .build(); LazyFrame::from_logical_plan(lp, opt_state) } + + // Finish with join predicates + pub fn join_where(self, predicates: Vec) -> LazyFrame { + let mut opt_state = self.lf.opt_state; + let other = self.other.expect("with not set"); + + // If any of the nodes reads from files we must activate this plan as well. + if other.opt_state.contains(OptFlags::FILE_CACHING) { + opt_state |= OptFlags::FILE_CACHING; + } + + let args = JoinArgs { + how: self.how, + validation: self.validation, + suffix: self.suffix, + slice: None, + join_nulls: self.join_nulls, + coalesce: self.coalesce, + }; + let options = JoinOptions { + allow_parallel: self.allow_parallel, + force_parallel: self.force_parallel, + args, + ..Default::default() + }; + + let lp = DslPlan::Join { + input_left: Arc::new(self.lf.logical_plan), + input_right: Arc::new(other.logical_plan), + left_on: Default::default(), + right_on: Default::default(), + predicates, + options: Arc::from(options), + }; + + LazyFrame::from_logical_plan(lp, opt_state) + } } diff --git a/crates/polars-lazy/src/frame/pivot.rs b/crates/polars-lazy/src/frame/pivot.rs index 759981c52f0e..4d89eebef010 100644 --- a/crates/polars-lazy/src/frame/pivot.rs +++ b/crates/polars-lazy/src/frame/pivot.rs @@ -1,3 +1,5 @@ +//! Module containing implementation of the pivot operation. +//! //! Polars lazy does not implement a pivot because it is impossible to know the schema without //! materializing the whole dataset. This makes a pivot quite a terrible operation for performant //! workflows. 
An optimization can never be pushed down passed a pivot. @@ -19,14 +21,19 @@ impl PhysicalAggExpr for PivotExpr { fn evaluate(&self, df: &DataFrame, groups: &GroupsProxy) -> PolarsResult { let state = ExecutionState::new(); let dtype = df.get_columns()[0].dtype(); - let phys_expr = prepare_expression_for_context("", &self.0, dtype, Context::Aggregation)?; + let phys_expr = prepare_expression_for_context( + PlSmallStr::EMPTY, + &self.0, + dtype, + Context::Aggregation, + )?; phys_expr .evaluate_on_groups(df, groups, &state) .map(|mut ac| ac.aggregated()) } - fn root_name(&self) -> PolarsResult<&str> { - Ok("") + fn root_name(&self) -> PolarsResult<&PlSmallStr> { + Ok(PlSmallStr::EMPTY_REF) } } @@ -44,9 +51,9 @@ where I0: IntoIterator, I1: IntoIterator, I2: IntoIterator, - S0: AsRef, - S1: AsRef, - S2: AsRef, + S0: Into, + S1: Into, + S2: Into, { // make sure that the root column is replaced let agg_expr = agg_expr.map(|agg_expr| { @@ -70,9 +77,9 @@ where I0: IntoIterator, I1: IntoIterator, I2: IntoIterator, - S0: AsRef, - S1: AsRef, - S2: AsRef, + S0: Into, + S1: Into, + S2: Into, { // make sure that the root column is replaced let agg_expr = agg_expr.map(|agg_expr| { diff --git a/crates/polars-lazy/src/lib.rs b/crates/polars-lazy/src/lib.rs index 46d1304a0b96..005a09186ba2 100644 --- a/crates/polars-lazy/src/lib.rs +++ b/crates/polars-lazy/src/lib.rs @@ -61,7 +61,7 @@ //! assert!(new.column("new_column") //! .unwrap() //! .equals( -//! &Series::new("new_column", &[50, 40, 30, 20, 10]) +//! &Series::new("new_column".into(), &[50, 40, 30, 20, 10]) //! ) //! ); //! ``` @@ -94,7 +94,7 @@ //! assert!(new.column("new_column") //! .unwrap() //! .equals( -//! &Series::new("new_column", &[100, 100, 3, 4, 5]) +//! &Series::new("new_column".into(), &[100, 100, 3, 4, 5]) //! ) //! ); //! ``` @@ -147,7 +147,7 @@ //! col("column_a") //! // apply a custom closure Series => Result //! .map(|_s| { -//! Ok(Some(Series::new("", &[6.0f32, 6.0, 6.0, 6.0, 6.0]))) +//! Ok(Some(Series::new("".into(), &[6.0f32, 6.0, 6.0, 6.0, 6.0]))) //! }, //! // return type of the closure //! 
GetOutput::from_type(DataType::Float64)).alias("new_column") @@ -206,6 +206,7 @@ pub mod dsl; pub mod frame; pub mod physical_plan; pub mod prelude; + mod scan; #[cfg(test)] mod tests; diff --git a/crates/polars-lazy/src/physical_plan/exotic.rs b/crates/polars-lazy/src/physical_plan/exotic.rs index 0e2a68d9f562..453337e616f8 100644 --- a/crates/polars-lazy/src/physical_plan/exotic.rs +++ b/crates/polars-lazy/src/physical_plan/exotic.rs @@ -6,14 +6,14 @@ use crate::prelude::*; #[cfg(feature = "pivot")] pub(crate) fn prepare_eval_expr(expr: Expr) -> Expr { expr.map_expr(|e| match e { - Expr::Column(_) => Expr::Column(Arc::from("")), - Expr::Nth(_) => Expr::Column(Arc::from("")), + Expr::Column(_) => Expr::Column(PlSmallStr::EMPTY), + Expr::Nth(_) => Expr::Column(PlSmallStr::EMPTY), e => e, }) } pub(crate) fn prepare_expression_for_context( - name: &str, + name: PlSmallStr, expr: &Expr, dtype: &DataType, ctxt: Context, diff --git a/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs b/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs index 2a1afc353183..777f769866d0 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs @@ -26,7 +26,7 @@ impl PhysicalIoExpr for Wrap { }; h.evaluate_io(df) } - fn live_variables(&self) -> Option>> { + fn live_variables(&self) -> Option> { // @TODO: This should not unwrap Some(expr_to_leaf_column_names(self.0.as_expression()?)) } @@ -244,13 +244,13 @@ fn get_pipeline_node( // so we just create a scan that returns an empty df let dummy = lp_arena.add(IR::DataFrameScan { df: Arc::new(DataFrame::empty()), - schema: Arc::new(Schema::new()), + schema: Arc::new(Schema::default()), output_schema: None, filter: None, }); IR::MapFunction { - function: FunctionNode::Pipeline { + function: FunctionIR::Pipeline { function: Arc::new(Mutex::new(move |_df: DataFrame| { let state = ExecutionState::new(); if state.verbose() { diff --git a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs index a014d5f70e5c..7100c083bd47 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs @@ -212,7 +212,7 @@ pub(crate) fn insert_streaming_nodes( // Rechunks are ignored MapFunction { input, - function: FunctionNode::Rechunk, + function: FunctionIR::Rechunk, } => { state.streamable = true; stack.push(StackFrame::new(*input, state, current_idx)) @@ -294,7 +294,7 @@ pub(crate) fn insert_streaming_nodes( Scan { .. } => true, MapFunction { input, - function: FunctionNode::Rechunk, + function: FunctionIR::Rechunk, } => matches!(lp_arena.get(*input), Scan { .. }), _ => false, }) => @@ -356,7 +356,7 @@ pub(crate) fn insert_streaming_nodes( #[cfg(feature = "dtype-struct")] DataType::Struct(fields) => fields .iter() - .all(|fld| allowed_dtype(fld.data_type(), string_cache)), + .all(|fld| allowed_dtype(fld.dtype(), string_cache)), // We need to be able to sink to disk or produce the aggregate return dtype. 
DataType::Unknown(_) => false, #[cfg(feature = "dtype-decimal")] @@ -396,7 +396,7 @@ pub(crate) fn insert_streaming_nodes( let valid_types = || { output_schema - .iter_dtypes() + .iter_values() .all(|dt| allowed_dtype(dt, string_cache)) }; diff --git a/crates/polars-lazy/src/scan/anonymous_scan.rs b/crates/polars-lazy/src/scan/anonymous_scan.rs index 8b83046693da..4c3d9a03e723 100644 --- a/crates/polars-lazy/src/scan/anonymous_scan.rs +++ b/crates/polars-lazy/src/scan/anonymous_scan.rs @@ -42,7 +42,7 @@ impl LazyFrame { .into(); if let Some(rc) = args.row_index { - lf = lf.with_row_index(&rc.name, Some(rc.offset)) + lf = lf.with_row_index(rc.name.clone(), Some(rc.offset)) }; Ok(lf) diff --git a/crates/polars-lazy/src/scan/csv.rs b/crates/polars-lazy/src/scan/csv.rs index 27b5870dab13..998f422820c6 100644 --- a/crates/polars-lazy/src/scan/csv.rs +++ b/crates/polars-lazy/src/scan/csv.rs @@ -5,6 +5,7 @@ use polars_io::cloud::CloudOptions; use polars_io::csv::read::{ infer_file_schema, CommentPrefix, CsvEncoding, CsvParseOptions, CsvReadOptions, NullValues, }; +use polars_io::mmap::ReaderBytes; use polars_io::path_utils::expand_paths; use polars_io::utils::get_reader_bytes; use polars_io::RowIndex; @@ -14,12 +15,12 @@ use crate::prelude::*; #[derive(Clone)] #[cfg(feature = "csv")] pub struct LazyCsvReader { - paths: Arc>, + sources: ScanSources, glob: bool, cache: bool, read_options: CsvReadOptions, cloud_options: Option, - include_file_paths: Option>, + include_file_paths: Option, } #[cfg(feature = "csv")] @@ -30,13 +31,13 @@ impl LazyCsvReader { self } - pub fn new_paths(paths: Arc>) -> Self { - Self::new("").with_paths(paths) + pub fn new_paths(paths: Arc<[PathBuf]>) -> Self { + Self::new_with_sources(ScanSources::Paths(paths)) } - pub fn new(path: impl AsRef) -> Self { + pub fn new_with_sources(sources: ScanSources) -> Self { LazyCsvReader { - paths: Arc::new(vec![path.as_ref().to_path_buf()]), + sources, glob: true, cache: true, read_options: Default::default(), @@ -45,6 +46,10 @@ impl LazyCsvReader { } } + pub fn new(path: impl AsRef) -> Self { + Self::new_with_sources(ScanSources::Paths([path.as_ref().to_path_buf()].into())) + } + /// Skip this number of rows after the header location. #[must_use] pub fn with_skip_rows_after_header(mut self, offset: usize) -> Self { @@ -120,13 +125,13 @@ impl LazyCsvReader { /// Set the comment prefix for this instance. Lines starting with this prefix will be ignored. 
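// A minimal sketch of the two CSV constructors above; the path is hypothetical
// and the usual std::path / polars prelude imports are assumed.
fn csv_readers() -> (LazyCsvReader, LazyCsvReader) {
    let by_path = LazyCsvReader::new("foods.csv");
    let by_sources = LazyCsvReader::new_with_sources(ScanSources::Paths(
        [PathBuf::from("foods.csv")].into(),
    ));
    (by_path, by_sources)
}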
#[must_use] - pub fn with_comment_prefix(self, comment_prefix: Option<&str>) -> Self { + pub fn with_comment_prefix(self, comment_prefix: Option) -> Self { self.map_parse_options(|opts| { - opts.with_comment_prefix(comment_prefix.map(|s| { + opts.with_comment_prefix(comment_prefix.clone().map(|s| { if s.len() == 1 && s.chars().next().unwrap().is_ascii() { CommentPrefix::Single(s.as_bytes()[0]) } else { - CommentPrefix::Multi(Arc::from(s)) + CommentPrefix::Multi(s) } })) }) @@ -219,38 +224,71 @@ impl LazyCsvReader { where F: Fn(Schema) -> PolarsResult, { - // TODO: Path expansion should happen when converting to the IR - // https://github.com/pola-rs/polars/issues/17634 - let paths = expand_paths(self.paths(), self.glob(), self.cloud_options())?; + let mut n_threads = self.read_options.n_threads; + + let mut infer_schema = |reader_bytes: ReaderBytes| { + let skip_rows = self.read_options.skip_rows; + let parse_options = self.read_options.get_parse_options(); + + PolarsResult::Ok( + infer_file_schema( + &reader_bytes, + parse_options.separator, + self.read_options.infer_schema_length, + self.read_options.has_header, + // we set it to None and modify them after the schema is updated + None, + skip_rows, + self.read_options.skip_rows_after_header, + parse_options.comment_prefix.as_ref(), + parse_options.quote_char, + parse_options.eol_char, + None, + parse_options.try_parse_dates, + self.read_options.raise_if_empty, + &mut n_threads, + parse_options.decimal_comma, + )? + .0, + ) + }; - let Some(path) = paths.first() else { - polars_bail!(ComputeError: "no paths specified for this reader"); + let schema = match self.sources.clone() { + ScanSources::Paths(paths) => { + // TODO: Path expansion should happen when converting to the IR + // https://github.com/pola-rs/polars/issues/17634 + let paths = expand_paths(&paths[..], self.glob(), self.cloud_options())?; + + let Some(path) = paths.first() else { + polars_bail!(ComputeError: "no paths specified for this reader"); + }; + + let mut file = polars_utils::open_file(path)?; + infer_schema(get_reader_bytes(&mut file).expect("could not mmap file"))? + }, + ScanSources::Files(files) => { + let Some(file) = files.first() else { + polars_bail!(ComputeError: "no buffers specified for this reader"); + }; + + infer_schema( + get_reader_bytes(&mut std::io::BufReader::new(file)) + .expect("could not mmap file"), + )? + }, + ScanSources::Buffers(buffers) => { + let Some(buffer) = buffers.first() else { + polars_bail!(ComputeError: "no buffers specified for this reader"); + }; + + infer_schema( + get_reader_bytes(&mut std::io::Cursor::new(buffer)) + .expect("could not mmap file"), + )? 
+ }, }; - let mut file = polars_utils::open_file(path)?; - - let reader_bytes = get_reader_bytes(&mut file).expect("could not mmap file"); - let skip_rows = self.read_options.skip_rows; - let parse_options = self.read_options.get_parse_options(); - - let (schema, _, _) = infer_file_schema( - &reader_bytes, - parse_options.separator, - self.read_options.infer_schema_length, - self.read_options.has_header, - // we set it to None and modify them after the schema is updated - None, - skip_rows, - self.read_options.skip_rows_after_header, - parse_options.comment_prefix.as_ref(), - parse_options.quote_char, - parse_options.eol_char, - None, - parse_options.try_parse_dates, - self.read_options.raise_if_empty, - &mut self.read_options.n_threads, - parse_options.decimal_comma, - )?; + self.read_options.n_threads = n_threads; let mut schema = f(schema)?; // the dtypes set may be for the new names, so update again @@ -263,7 +301,7 @@ impl LazyCsvReader { Ok(self.with_schema(Some(Arc::new(schema)))) } - pub fn with_include_file_paths(mut self, include_file_paths: Option>) -> Self { + pub fn with_include_file_paths(mut self, include_file_paths: Option) -> Self { self.include_file_paths = include_file_paths; self } @@ -273,7 +311,7 @@ impl LazyFileListReader for LazyCsvReader { /// Get the final [LazyFrame]. fn finish(self) -> PolarsResult { let mut lf: LazyFrame = DslBuilder::scan_csv( - self.paths, + self.sources.to_dsl(false), self.read_options, self.cache, self.cloud_options, @@ -282,7 +320,7 @@ impl LazyFileListReader for LazyCsvReader { )? .build() .into(); - lf.opt_state |= OptState::FILE_CACHING; + lf.opt_state |= OptFlags::FILE_CACHING; Ok(lf) } @@ -294,12 +332,12 @@ impl LazyFileListReader for LazyCsvReader { self.glob } - fn paths(&self) -> &[PathBuf] { - &self.paths + fn sources(&self) -> &ScanSources { + &self.sources } - fn with_paths(mut self, paths: Arc>) -> Self { - self.paths = paths; + fn with_sources(mut self, sources: ScanSources) -> Self { + self.sources = sources; self } diff --git a/crates/polars-lazy/src/scan/file_list_reader.rs b/crates/polars-lazy/src/scan/file_list_reader.rs index 9c716afa060c..28315c96f736 100644 --- a/crates/polars-lazy/src/scan/file_list_reader.rs +++ b/crates/polars-lazy/src/scan/file_list_reader.rs @@ -1,4 +1,5 @@ use std::path::PathBuf; +use std::sync::Arc; use polars_core::prelude::*; use polars_io::cloud::CloudOptions; @@ -18,8 +19,11 @@ pub trait LazyFileListReader: Clone { return self.finish_no_glob(); } - let lfs = self - .paths() + let ScanSources::Paths(paths) = self.sources() else { + unreachable!("opened-files or in-memory buffers should never be globbed"); + }; + + let lfs = paths .iter() .map(|path| { self.clone() @@ -27,7 +31,7 @@ pub trait LazyFileListReader: Clone { .with_n_rows(None) // Each individual reader should not apply a row index. 
.with_row_index(None) - .with_paths(Arc::new(vec![path.clone()])) + .with_paths([path.clone()].into()) .with_rechunk(false) .finish_no_glob() .map_err(|e| { @@ -40,7 +44,7 @@ pub trait LazyFileListReader: Clone { polars_ensure!( !lfs.is_empty(), - ComputeError: "no matching files found in {:?}", self.paths().iter().map(|x| x.to_str().unwrap()).collect::>() + ComputeError: "no matching files found in {:?}", paths.iter().map(|x| x.to_str().unwrap()).collect::>() ); let mut lf = self.concat_impl(lfs)?; @@ -48,7 +52,7 @@ pub trait LazyFileListReader: Clone { lf = lf.slice(0, n_rows as IdxSize) }; if let Some(rc) = self.row_index() { - lf = lf.with_row_index(&rc.name, Some(rc.offset)) + lf = lf.with_row_index(rc.name.clone(), Some(rc.offset)) }; Ok(lf) @@ -79,11 +83,18 @@ pub trait LazyFileListReader: Clone { true } - fn paths(&self) -> &[PathBuf]; + /// Get the sources for this reader. + fn sources(&self) -> &ScanSources; + + /// Set sources of the scanned files. + #[must_use] + fn with_sources(self, source: ScanSources) -> Self; /// Set paths of the scanned files. #[must_use] - fn with_paths(self, paths: Arc>) -> Self; + fn with_paths(self, paths: Arc<[PathBuf]>) -> Self { + self.with_sources(ScanSources::Paths(paths)) + } /// Configure the row limit. fn with_n_rows(self, n_rows: impl Into>) -> Self; diff --git a/crates/polars-lazy/src/scan/ipc.rs b/crates/polars-lazy/src/scan/ipc.rs index 09391751eae4..a9f8c8b98b0f 100644 --- a/crates/polars-lazy/src/scan/ipc.rs +++ b/crates/polars-lazy/src/scan/ipc.rs @@ -13,10 +13,9 @@ pub struct ScanArgsIpc { pub cache: bool, pub rechunk: bool, pub row_index: Option, - pub memory_map: bool, pub cloud_options: Option, pub hive_options: HiveOptions, - pub include_file_paths: Option>, + pub include_file_paths: Option, } impl Default for ScanArgsIpc { @@ -26,7 +25,6 @@ impl Default for ScanArgsIpc { cache: true, rechunk: false, row_index: None, - memory_map: true, cloud_options: Default::default(), hive_options: Default::default(), include_file_paths: None, @@ -37,29 +35,26 @@ impl Default for ScanArgsIpc { #[derive(Clone)] struct LazyIpcReader { args: ScanArgsIpc, - paths: Arc>, + sources: ScanSources, } impl LazyIpcReader { fn new(args: ScanArgsIpc) -> Self { Self { args, - paths: Arc::new(vec![]), + sources: ScanSources::default(), } } } impl LazyFileListReader for LazyIpcReader { fn finish(self) -> PolarsResult { - let paths = self.paths; let args = self.args; - let options = IpcScanOptions { - memory_map: args.memory_map, - }; + let options = IpcScanOptions {}; let mut lf: LazyFrame = DslBuilder::scan_ipc( - paths, + self.sources.to_dsl(false), options, args.n_rows, args.cache, @@ -71,7 +66,7 @@ impl LazyFileListReader for LazyIpcReader { )? .build() .into(); - lf.opt_state |= OptState::FILE_CACHING; + lf.opt_state |= OptFlags::FILE_CACHING; Ok(lf) } @@ -80,12 +75,12 @@ impl LazyFileListReader for LazyIpcReader { unreachable!() } - fn paths(&self) -> &[PathBuf] { - &self.paths + fn sources(&self) -> &ScanSources { + &self.sources } - fn with_paths(mut self, paths: Arc>) -> Self { - self.paths = paths; + fn with_sources(mut self, sources: ScanSources) -> Self { + self.sources = sources; self } @@ -125,12 +120,17 @@ impl LazyFileListReader for LazyIpcReader { impl LazyFrame { /// Create a LazyFrame directly from a ipc scan. 
pub fn scan_ipc(path: impl AsRef, args: ScanArgsIpc) -> PolarsResult { - LazyIpcReader::new(args) - .with_paths(Arc::new(vec![path.as_ref().to_path_buf()])) - .finish() + Self::scan_ipc_sources( + ScanSources::Paths([path.as_ref().to_path_buf()].into()), + args, + ) + } + + pub fn scan_ipc_files(paths: Arc<[PathBuf]>, args: ScanArgsIpc) -> PolarsResult { + Self::scan_ipc_sources(ScanSources::Paths(paths), args) } - pub fn scan_ipc_files(paths: Arc>, args: ScanArgsIpc) -> PolarsResult { - LazyIpcReader::new(args).with_paths(paths).finish() + pub fn scan_ipc_sources(sources: ScanSources, args: ScanArgsIpc) -> PolarsResult { + LazyIpcReader::new(args).with_sources(sources).finish() } } diff --git a/crates/polars-lazy/src/scan/ndjson.rs b/crates/polars-lazy/src/scan/ndjson.rs index 6cb4a8c8cae7..e38270ec3e09 100644 --- a/crates/polars-lazy/src/scan/ndjson.rs +++ b/crates/polars-lazy/src/scan/ndjson.rs @@ -4,8 +4,8 @@ use std::sync::{Arc, Mutex, RwLock}; use polars_core::prelude::*; use polars_io::cloud::CloudOptions; -use polars_io::RowIndex; -use polars_plan::plans::{DslPlan, FileScan}; +use polars_io::{HiveOptions, RowIndex}; +use polars_plan::plans::{DslPlan, FileScan, ScanSources}; use polars_plan::prelude::{FileScanOptions, NDJsonReadOptions}; use crate::prelude::LazyFrame; @@ -13,7 +13,7 @@ use crate::scan::file_list_reader::LazyFileListReader; #[derive(Clone)] pub struct LazyJsonLineReader { - pub(crate) paths: Arc>, + pub(crate) sources: ScanSources, pub(crate) batch_size: Option, pub(crate) low_memory: bool, pub(crate) rechunk: bool, @@ -23,18 +23,18 @@ pub struct LazyJsonLineReader { pub(crate) infer_schema_length: Option, pub(crate) n_rows: Option, pub(crate) ignore_errors: bool, - pub(crate) include_file_paths: Option>, + pub(crate) include_file_paths: Option, pub(crate) cloud_options: Option, } impl LazyJsonLineReader { - pub fn new_paths(paths: Arc>) -> Self { - Self::new(PathBuf::new()).with_paths(paths) + pub fn new_paths(paths: Arc<[PathBuf]>) -> Self { + Self::new_with_sources(ScanSources::Paths(paths)) } - pub fn new(path: impl AsRef) -> Self { + pub fn new_with_sources(sources: ScanSources) -> Self { LazyJsonLineReader { - paths: Arc::new(vec![path.as_ref().to_path_buf()]), + sources, batch_size: None, low_memory: false, rechunk: false, @@ -48,6 +48,11 @@ impl LazyJsonLineReader { cloud_options: None, } } + + pub fn new(path: impl AsRef) -> Self { + Self::new_with_sources(ScanSources::Paths([path.as_ref().to_path_buf()].into())) + } + /// Add a row index column. 
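// A minimal sketch of the IPC scan entry points above; the path is hypothetical
// and default scan arguments are used for brevity.
fn scan_ipc_both_ways() -> PolarsResult<(LazyFrame, LazyFrame)> {
    let by_path = LazyFrame::scan_ipc("data.ipc", ScanArgsIpc::default())?;
    let by_sources = LazyFrame::scan_ipc_sources(
        ScanSources::Paths([PathBuf::from("data.ipc")].into()),
        ScanArgsIpc::default(),
    )?;
    Ok((by_path, by_sources))
}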
#[must_use] pub fn with_row_index(mut self, row_index: Option) -> Self { @@ -109,7 +114,7 @@ impl LazyJsonLineReader { self } - pub fn with_include_file_paths(mut self, include_file_paths: Option>) -> Self { + pub fn with_include_file_paths(mut self, include_file_paths: Option) -> Self { self.include_file_paths = include_file_paths; self } @@ -117,8 +122,6 @@ impl LazyJsonLineReader { impl LazyFileListReader for LazyJsonLineReader { fn finish(self) -> PolarsResult { - let paths = Arc::new(Mutex::new((self.paths, false))); - let file_options = FileScanOptions { slice: self.n_rows.map(|x| (0, x)), with_columns: None, @@ -126,7 +129,12 @@ impl LazyFileListReader for LazyJsonLineReader { row_index: self.row_index, rechunk: self.rechunk, file_counter: 0, - hive_options: Default::default(), + hive_options: HiveOptions { + enabled: Some(false), + hive_start_idx: 0, + schema: None, + try_parse_dates: true, + }, glob: true, include_file_paths: self.include_file_paths, }; @@ -147,7 +155,7 @@ impl LazyFileListReader for LazyJsonLineReader { }; Ok(LazyFrame::from(DslPlan::Scan { - paths, + sources: Arc::new(Mutex::new(self.sources.to_dsl(false))), file_info: Arc::new(RwLock::new(None)), hive_parts: None, predicate: None, @@ -160,12 +168,12 @@ impl LazyFileListReader for LazyJsonLineReader { unreachable!(); } - fn paths(&self) -> &[PathBuf] { - &self.paths + fn sources(&self) -> &ScanSources { + &self.sources } - fn with_paths(mut self, paths: Arc>) -> Self { - self.paths = paths; + fn with_sources(mut self, sources: ScanSources) -> Self { + self.sources = sources; self } diff --git a/crates/polars-lazy/src/scan/parquet.rs b/crates/polars-lazy/src/scan/parquet.rs index ba6563906914..9adb0f1838be 100644 --- a/crates/polars-lazy/src/scan/parquet.rs +++ b/crates/polars-lazy/src/scan/parquet.rs @@ -20,7 +20,7 @@ pub struct ScanArgsParquet { pub cache: bool, /// Expand path given via globbing rules. pub glob: bool, - pub include_file_paths: Option>, + pub include_file_paths: Option, } impl Default for ScanArgsParquet { @@ -44,14 +44,14 @@ impl Default for ScanArgsParquet { #[derive(Clone)] struct LazyParquetReader { args: ScanArgsParquet, - paths: Arc>, + sources: ScanSources, } impl LazyParquetReader { fn new(args: ScanArgsParquet) -> Self { Self { args, - paths: Arc::new(vec![]), + sources: ScanSources::default(), } } } @@ -62,7 +62,7 @@ impl LazyFileListReader for LazyParquetReader { let row_index = self.args.row_index; let mut lf: LazyFrame = DslBuilder::scan_parquet( - self.paths, + self.sources.to_dsl(false), self.args.n_rows, self.args.cache, self.args.parallel, @@ -80,10 +80,10 @@ impl LazyFileListReader for LazyParquetReader { // It's a bit hacky, but this row_index function updates the schema. if let Some(row_index) = row_index { - lf = lf.with_row_index(&row_index.name, Some(row_index.offset)) + lf = lf.with_row_index(row_index.name.clone(), Some(row_index.offset)) } - lf.opt_state |= OptState::FILE_CACHING; + lf.opt_state |= OptFlags::FILE_CACHING; Ok(lf) } @@ -95,12 +95,12 @@ impl LazyFileListReader for LazyParquetReader { unreachable!(); } - fn paths(&self) -> &[PathBuf] { - &self.paths + fn sources(&self) -> &ScanSources { + &self.sources } - fn with_paths(mut self, paths: Arc>) -> Self { - self.paths = paths; + fn with_sources(mut self, sources: ScanSources) -> Self { + self.sources = sources; self } @@ -139,16 +139,19 @@ impl LazyFileListReader for LazyParquetReader { impl LazyFrame { /// Create a LazyFrame directly from a parquet scan. 
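// A minimal sketch of the NDJSON reader changes above, assuming the
// `LazyFileListReader` trait is in scope; the path and the file-path column
// name are hypothetical.
fn scan_ndjson_with_paths() -> PolarsResult<LazyFrame> {
    LazyJsonLineReader::new("logs.ndjson")
        .with_include_file_paths(Some(PlSmallStr::from_static("source_file")))
        .finish()
}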
pub fn scan_parquet(path: impl AsRef, args: ScanArgsParquet) -> PolarsResult { - LazyParquetReader::new(args) - .with_paths(Arc::new(vec![path.as_ref().to_path_buf()])) - .finish() + Self::scan_parquet_sources( + ScanSources::Paths([path.as_ref().to_path_buf()].into()), + args, + ) + } + + /// Create a LazyFrame directly from a parquet scan. + pub fn scan_parquet_sources(sources: ScanSources, args: ScanArgsParquet) -> PolarsResult { + LazyParquetReader::new(args).with_sources(sources).finish() } /// Create a LazyFrame directly from a parquet scan. - pub fn scan_parquet_files( - paths: Arc>, - args: ScanArgsParquet, - ) -> PolarsResult { - LazyParquetReader::new(args).with_paths(paths).finish() + pub fn scan_parquet_files(paths: Arc<[PathBuf]>, args: ScanArgsParquet) -> PolarsResult { + Self::scan_parquet_sources(ScanSources::Paths(paths), args) } } diff --git a/crates/polars-lazy/src/tests/aggregations.rs b/crates/polars-lazy/src/tests/aggregations.rs index 0e67cba50566..54387451a8b7 100644 --- a/crates/polars-lazy/src/tests/aggregations.rs +++ b/crates/polars-lazy/src/tests/aggregations.rs @@ -6,7 +6,7 @@ use super::*; #[test] #[cfg(feature = "dtype-datetime")] fn test_agg_list_type() -> PolarsResult<()> { - let s = Series::new("foo", &[1, 2, 3]); + let s = Series::new("foo".into(), &[1, 2, 3]); let s = s.cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?; let l = unsafe { s.agg_list(&GroupsProxy::Idx(vec![(0, unitvec![0, 1, 2])].into())) }; diff --git a/crates/polars-lazy/src/tests/io.rs b/crates/polars-lazy/src/tests/io.rs index 8c3f6e5334b2..a1d3f2c050a8 100644 --- a/crates/polars-lazy/src/tests/io.rs +++ b/crates/polars-lazy/src/tests/io.rs @@ -136,7 +136,7 @@ fn test_parquet_statistics() -> PolarsResult<()> { // issue: 13427 let out = scan_foods_parquet(par) - .filter(col("calories").is_in(lit(Series::new("", [0, 500])))) + .filter(col("calories").is_in(lit(Series::new("".into(), [0, 500])))) .collect()?; assert_eq!(out.shape(), (0, 4)); @@ -417,7 +417,6 @@ fn test_ipc_globbing() -> PolarsResult<()> { cache: true, rechunk: false, row_index: None, - memory_map: true, cloud_options: None, hive_options: Default::default(), include_file_paths: None, @@ -590,7 +589,7 @@ fn test_row_index_on_files() -> PolarsResult<()> { for offset in [0 as IdxSize, 10] { let lf = LazyCsvReader::new(FOODS_CSV) .with_row_index(Some(RowIndex { - name: Arc::from("index"), + name: PlSmallStr::from_static("index"), offset, })) .finish()?; @@ -665,7 +664,7 @@ fn scan_anonymous_fn_with_options() -> PolarsResult<()> { fn scan(&self, scan_opts: AnonymousScanArgs) -> PolarsResult { assert_eq!(scan_opts.with_columns.clone().unwrap().len(), 2); assert_eq!(scan_opts.n_rows, Some(3)); - let out = fruits_cars().select(scan_opts.with_columns.unwrap().as_ref())?; + let out = fruits_cars().select(scan_opts.with_columns.unwrap().iter().cloned())?; Ok(out.slice(0, scan_opts.n_rows.unwrap())) } } @@ -701,7 +700,7 @@ fn scan_small_dtypes() -> PolarsResult<()> { let df = LazyCsvReader::new(FOODS_CSV) .with_has_header(true) .with_dtype_overwrite(Some(Arc::new(Schema::from_iter([Field::new( - "sugars_g", + "sugars_g".into(), dt.clone(), )])))) .finish()? 
diff --git a/crates/polars-lazy/src/tests/mod.rs b/crates/polars-lazy/src/tests/mod.rs index 8b1a51212d18..f4ba3e876a65 100644 --- a/crates/polars-lazy/src/tests/mod.rs +++ b/crates/polars-lazy/src/tests/mod.rs @@ -6,14 +6,14 @@ mod cse; mod io; mod logical; mod optimization_checks; +#[cfg(all(feature = "strings", feature = "cse"))] +mod pdsh; mod predicate_queries; mod projection_queries; mod queries; mod schema; #[cfg(feature = "streaming")] mod streaming; -#[cfg(all(feature = "strings", feature = "cse"))] -mod tpch; fn get_arenas() -> (Arena, Arena) { let expr_arena = Arena::with_capacity(16); diff --git a/crates/polars-lazy/src/tests/optimization_checks.rs b/crates/polars-lazy/src/tests/optimization_checks.rs index 2ed1205241bc..e01ad342f061 100644 --- a/crates/polars-lazy/src/tests/optimization_checks.rs +++ b/crates/polars-lazy/src/tests/optimization_checks.rs @@ -65,7 +65,7 @@ pub(crate) fn is_pipeline(q: LazyFrame) -> bool { matches!( lp_arena.get(lp), IR::MapFunction { - function: FunctionNode::Pipeline { .. }, + function: FunctionIR::Pipeline { .. }, .. } ) @@ -79,7 +79,7 @@ pub(crate) fn has_pipeline(q: LazyFrame) -> bool { matches!( lp, IR::MapFunction { - function: FunctionNode::Pipeline { .. }, + function: FunctionIR::Pipeline { .. }, .. } ) @@ -308,7 +308,10 @@ pub fn test_predicate_block_cast() -> PolarsResult<()> { let out = lf.collect()?; let s = out.column("value").unwrap(); - assert_eq!(s, &Series::new("value", [1.0f32, 2.0])); + assert_eq!( + s, + &Series::new(PlSmallStr::from_static("value"), [1.0f32, 2.0]) + ); } Ok(()) @@ -495,8 +498,8 @@ fn test_with_column_prune() -> PolarsResult<()> { matches!(lp, SimpleProjection { .. } | DataFrameScan { .. }) })); assert_eq!( - q.schema().unwrap().as_ref(), - &Schema::from_iter([Field::new("c1", DataType::Int32)]) + q.collect_schema().unwrap().as_ref(), + &Schema::from_iter([Field::new(PlSmallStr::from_static("c1"), DataType::Int32)]) ); Ok(()) } diff --git a/crates/polars-lazy/src/tests/tpch.rs b/crates/polars-lazy/src/tests/pdsh.rs similarity index 83% rename from crates/polars-lazy/src/tests/tpch.rs rename to crates/polars-lazy/src/tests/pdsh.rs index 49eed184f72a..426b19506684 100644 --- a/crates/polars-lazy/src/tests/tpch.rs +++ b/crates/polars-lazy/src/tests/pdsh.rs @@ -1,10 +1,10 @@ -//! The tpch files only got ten rows, so after all the joins filters there is not data +//! The PDSH files only got ten rows, so after all the joins filters there is not data //! Still we can use this to test the schema, operation correctness on empty data, and optimizations //! taken. 
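// A minimal sketch of the renamed-constructor pattern used throughout the
// updated tests above: names are now `PlSmallStr`, so `&str` literals need an
// explicit conversion; the values shown are hypothetical.
fn named_series_and_field() -> (Series, Field) {
    let s = Series::new("value".into(), &[1.0f32, 2.0]);
    let f = Field::new(PlSmallStr::from_static("c1"), DataType::Int32);
    (s, f)
}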
use super::*; const fn base_path() -> &'static str { - "../../examples/datasets/tpc_heads" + "../../examples/datasets/pds_heads" } fn region() -> LazyFrame { @@ -98,14 +98,14 @@ fn test_q2() -> PolarsResult<()> { let out = q.collect()?; let schema = Schema::from_iter([ - Field::new("s_acctbal", DataType::Float64), - Field::new("s_name", DataType::String), - Field::new("n_name", DataType::String), - Field::new("p_partkey", DataType::Int64), - Field::new("p_mfgr", DataType::String), - Field::new("s_address", DataType::String), - Field::new("s_phone", DataType::String), - Field::new("s_comment", DataType::String), + Field::new("s_acctbal".into(), DataType::Float64), + Field::new("s_name".into(), DataType::String), + Field::new("n_name".into(), DataType::String), + Field::new("p_partkey".into(), DataType::Int64), + Field::new("p_mfgr".into(), DataType::String), + Field::new("s_address".into(), DataType::String), + Field::new("s_phone".into(), DataType::String), + Field::new("s_comment".into(), DataType::String), ]); assert_eq!(&out.schema(), &schema); diff --git a/crates/polars-lazy/src/tests/predicate_queries.rs b/crates/polars-lazy/src/tests/predicate_queries.rs index d3662579051c..855d9463f814 100644 --- a/crates/polars-lazy/src/tests/predicate_queries.rs +++ b/crates/polars-lazy/src/tests/predicate_queries.rs @@ -48,7 +48,7 @@ fn test_issue_2472() -> PolarsResult<()> { .extract(lit(r"(\d+-){4}(\w+)-"), 2) .cast(DataType::Int32) .alias("age"); - let predicate = col("age").is_in(lit(Series::new("", [2i32]))); + let predicate = col("age").is_in(lit(Series::new("".into(), [2i32]))); let out = base .clone() @@ -102,7 +102,7 @@ fn filter_added_column_issue_2470() -> PolarsResult<()> { fn filter_blocked_by_map() -> PolarsResult<()> { let df = fruits_cars(); - let allowed = OptState::default() & !OptState::PREDICATE_PUSHDOWN; + let allowed = OptFlags::default() & !OptFlags::PREDICATE_PUSHDOWN; let q = df .lazy() .map(Ok, allowed, None, None) diff --git a/crates/polars-lazy/src/tests/projection_queries.rs b/crates/polars-lazy/src/tests/projection_queries.rs index 43a6088f4efb..b2cff519c05a 100644 --- a/crates/polars-lazy/src/tests/projection_queries.rs +++ b/crates/polars-lazy/src/tests/projection_queries.rs @@ -128,7 +128,10 @@ fn concat_str_regex_expansion() -> PolarsResult<()> { .select([concat_str([col(r"^b_a_\d$")], ";", false).alias("concatenated")]) .collect()?; let s = out.column("concatenated")?; - assert_eq!(s, &Series::new("concatenated", ["a--;;", ";b--;", ";;c--"])); + assert_eq!( + s, + &Series::new("concatenated".into(), ["a--;;", ";b--;", ";;c--"]) + ); Ok(()) } diff --git a/crates/polars-lazy/src/tests/queries.rs b/crates/polars-lazy/src/tests/queries.rs index ade6df69c57e..49d7aa120ea4 100644 --- a/crates/polars-lazy/src/tests/queries.rs +++ b/crates/polars-lazy/src/tests/queries.rs @@ -46,10 +46,11 @@ fn test_lazy_alias() { } #[test] +#[cfg(feature = "pivot")] fn test_lazy_unpivot() { let df = get_df(); - let args = UnpivotArgs { + let args = UnpivotArgsDSL { on: vec!["sepal_length".into(), "sepal_width".into()], index: vec!["petal_width".into(), "petal_length".into()], ..Default::default() @@ -216,7 +217,10 @@ fn test_lazy_ternary_and_predicates() { let new = ldf.collect().unwrap(); let length = new.column("sepal_length").unwrap(); - assert_eq!(length, &Series::new("sepal_length", &[5.1f64, 5.0, 5.4])); + assert_eq!( + length, + &Series::new("sepal_length".into(), &[5.1f64, 5.0, 5.4]) + ); assert_eq!(new.shape(), (3, 6)); } @@ -343,7 +347,7 @@ fn test_lazy_query_8() -> 
PolarsResult<()> { let mut selection = vec![]; - for c in &["A", "B", "C", "D", "E"] { + for &c in &["A", "B", "C", "D", "E"] { let e = when(col(c).is_in(col("E"))) .then(col("A")) .otherwise(Null {}.lit()) @@ -411,7 +415,7 @@ fn test_lazy_query_10() { use polars_core::export::chrono::Duration as ChronoDuration; let date = NaiveDate::from_ymd_opt(2021, 3, 5).unwrap(); let x: Series = DatetimeChunked::from_naive_datetime( - "x", + "x".into(), [ NaiveDateTime::new(date, NaiveTime::from_hms_opt(12, 0, 0).unwrap()), NaiveDateTime::new(date, NaiveTime::from_hms_opt(13, 0, 0).unwrap()), @@ -421,7 +425,7 @@ fn test_lazy_query_10() { ) .into(); let y: Series = DatetimeChunked::from_naive_datetime( - "y", + "y".into(), [ NaiveDateTime::new(date, NaiveTime::from_hms_opt(11, 0, 0).unwrap()), NaiveDateTime::new(date, NaiveTime::from_hms_opt(11, 0, 0).unwrap()), @@ -437,7 +441,7 @@ fn test_lazy_query_10() { .collect() .unwrap(); let z: Series = DurationChunked::from_duration( - "z", + "z".into(), [ ChronoDuration::try_hours(1).unwrap(), ChronoDuration::try_hours(2).unwrap(), @@ -448,7 +452,7 @@ fn test_lazy_query_10() { .into(); assert!(out.column("z").unwrap().equals(&z)); let x: Series = DatetimeChunked::from_naive_datetime( - "x", + "x".into(), [ NaiveDateTime::new(date, NaiveTime::from_hms_opt(2, 0, 0).unwrap()), NaiveDateTime::new(date, NaiveTime::from_hms_opt(3, 0, 0).unwrap()), @@ -458,7 +462,7 @@ fn test_lazy_query_10() { ) .into(); let y: Series = DatetimeChunked::from_naive_datetime( - "y", + "y".into(), [ NaiveDateTime::new(date, NaiveTime::from_hms_opt(1, 0, 0).unwrap()), NaiveDateTime::new(date, NaiveTime::from_hms_opt(1, 0, 0).unwrap()), @@ -497,8 +501,8 @@ fn test_lazy_query_7() { ]; let data = vec![Some(1.), Some(2.), Some(3.), Some(4.), None, None]; let df = DataFrame::new(vec![ - DatetimeChunked::from_naive_datetime("date", dates, TimeUnit::Nanoseconds).into(), - Series::new("data", data), + DatetimeChunked::from_naive_datetime("date".into(), dates, TimeUnit::Nanoseconds).into(), + Series::new("data".into(), data), ]) .unwrap(); // this tests if predicate pushdown not interferes with the shift data. @@ -519,7 +523,7 @@ fn test_lazy_query_7() { #[test] fn test_lazy_shift_and_fill_all() { let data = &[1, 2, 3]; - let df = DataFrame::new(vec![Series::new("data", data)]).unwrap(); + let df = DataFrame::new(vec![Series::new("data".into(), data)]).unwrap(); let out = df .lazy() .with_column(col("data").shift(lit(1)).fill_null(lit(0)).alias("output")) @@ -557,7 +561,15 @@ fn test_simplify_expr() { let mut expr_arena = Arena::new(); let mut lp_arena = Arena::new(); - let lp_top = to_alp(plan, &mut expr_arena, &mut lp_arena, true, false).unwrap(); + + #[allow(const_item_mutation)] + let lp_top = to_alp( + plan, + &mut expr_arena, + &mut lp_arena, + &mut OptFlags::SIMPLIFY_EXPR, + ) + .unwrap(); let plan = node_to_lp(lp_top, &expr_arena, &mut lp_arena); assert!( matches!(plan, DslPlan::Select{ expr, ..} if matches!(&expr[0], Expr::BinaryExpr{left, ..} if **left == Expr::Literal(LiteralValue::Float(2.0)))) @@ -636,7 +648,7 @@ fn test_type_coercion() { let mut expr_arena = Arena::new(); let mut lp_arena = Arena::new(); - let lp_top = to_alp(lp, &mut expr_arena, &mut lp_arena, true, true).unwrap(); + let lp_top = to_alp(lp, &mut expr_arena, &mut lp_arena, &mut OptFlags::default()).unwrap(); let lp = node_to_lp(lp_top, &expr_arena, &mut lp_arena); if let DslPlan::Select { expr, .. 
} = lp { @@ -702,7 +714,7 @@ fn test_lazy_group_by_apply() { df.lazy() .group_by([col("fruits")]) .agg([col("cars").apply( - |s: Series| Ok(Some(Series::new("", &[s.len() as u32]))), + |s: Series| Ok(Some(Series::new("".into(), &[s.len() as u32]))), GetOutput::from_type(DataType::UInt32), )]) .collect() @@ -1154,9 +1166,9 @@ fn test_fill_forward() -> PolarsResult<()> { let agg = out.column("b")?.list()?; let a: Series = agg.get_as_series(0).unwrap(); - assert!(a.equals(&Series::new("b", &[1, 1]))); + assert!(a.equals(&Series::new("b".into(), &[1, 1]))); let a: Series = agg.get_as_series(2).unwrap(); - assert!(a.equals(&Series::new("b", &[1, 1]))); + assert!(a.equals(&Series::new("b".into(), &[1, 1]))); let a: Series = agg.get_as_series(1).unwrap(); assert_eq!(a.null_count(), 1); Ok(()) @@ -1439,7 +1451,7 @@ fn test_when_then_schema() -> PolarsResult<()> { .select([when(col("A").gt(lit(1))) .then(Null {}.lit()) .otherwise(col("A"))]) - .schema(); + .collect_schema(); assert_ne!(schema?.get_at_index(0).unwrap().1, &DataType::Null); Ok(()) @@ -1459,8 +1471,8 @@ fn test_singleton_broadcast() -> PolarsResult<()> { #[test] fn test_list_in_select_context() -> PolarsResult<()> { - let s = Series::new("a", &[1, 2, 3]); - let mut builder = get_list_builder(s.dtype(), s.len(), 1, s.name()).unwrap(); + let s = Series::new("a".into(), &[1, 2, 3]); + let mut builder = get_list_builder(s.dtype(), s.len(), 1, s.name().clone()).unwrap(); builder.append_series(&s).unwrap(); let expected = builder.finish().into_series(); @@ -1537,8 +1549,8 @@ fn test_round_after_agg() -> PolarsResult<()> { #[test] #[cfg(feature = "dtype-date")] fn test_fill_nan() -> PolarsResult<()> { - let s0 = Series::new("date", &[1, 2, 3]).cast(&DataType::Date)?; - let s1 = Series::new("float", &[Some(1.0), Some(f32::NAN), Some(3.0)]); + let s0 = Series::new("date".into(), &[1, 2, 3]).cast(&DataType::Date)?; + let s1 = Series::new("float".into(), &[Some(1.0), Some(f32::NAN), Some(3.0)]); let df = DataFrame::new(vec![s0, s1])?; let out = df.lazy().fill_nan(Null {}.lit()).collect()?; @@ -1685,7 +1697,7 @@ fn test_single_ranked_group() -> PolarsResult<()> { #[cfg(feature = "diff")] fn empty_df() -> PolarsResult<()> { let df = fruits_cars(); - let df = df.filter(&BooleanChunked::full("", false, df.height()))?; + let df = df.filter(&BooleanChunked::full("".into(), false, df.height()))?; df.lazy() .select([ @@ -1748,7 +1760,7 @@ fn test_is_in() -> PolarsResult<()> { let out = df .lazy() .group_by_stable([col("fruits")]) - .agg([col("cars").is_in(lit(Series::new("a", ["beetle", "vw"])))]) + .agg([col("cars").is_in(lit(Series::new("a".into(), ["beetle", "vw"])))]) .collect()?; let out = out.column("cars").unwrap(); diff --git a/crates/polars-lazy/src/tests/schema.rs b/crates/polars-lazy/src/tests/schema.rs index c51f15d4b4b7..e4166d33c94a 100644 --- a/crates/polars-lazy/src/tests/schema.rs +++ b/crates/polars-lazy/src/tests/schema.rs @@ -27,13 +27,13 @@ fn test_schema_update_after_projection_pd() -> PolarsResult<()> { assert!(matches!( lp, IR::MapFunction { - function: FunctionNode::Explode { .. }, + function: FunctionIR::Explode { .. }, .. 
} )); let schema = lp.schema(&lp_arena).into_owned(); - let mut expected = Schema::new(); + let mut expected = Schema::default(); expected.with_column("a".into(), DataType::Int32); expected.with_column("b".into(), DataType::Int32); assert_eq!(schema.as_ref(), &expected); diff --git a/crates/polars-lazy/src/tests/streaming.rs b/crates/polars-lazy/src/tests/streaming.rs index d8d76384ed0c..d76d4c90dc2e 100644 --- a/crates/polars-lazy/src/tests/streaming.rs +++ b/crates/polars-lazy/src/tests/streaming.rs @@ -264,7 +264,7 @@ fn test_streaming_left_join() -> PolarsResult<()> { #[cfg(feature = "cross_join")] fn test_streaming_slice() -> PolarsResult<()> { let vals = (0..100).collect::>(); - let s = Series::new("", vals); + let s = Series::new("".into(), vals); let lf_a = df![ "a" => s ]? diff --git a/crates/polars-mem-engine/src/executors/group_by_partitioned.rs b/crates/polars-mem-engine/src/executors/group_by_partitioned.rs index 3867012d3f0c..ec4a691eb547 100644 --- a/crates/polars-mem-engine/src/executors/group_by_partitioned.rs +++ b/crates/polars-mem-engine/src/executors/group_by_partitioned.rs @@ -153,7 +153,7 @@ fn estimate_unique_count(keys: &[Series], mut sample_size: usize) -> PolarsResul .map(|s| s.slice(offset, sample_size)) .collect::>(); let df = unsafe { DataFrame::new_no_checks(keys) }; - let names = df.get_column_names(); + let names = df.get_column_names().into_iter().cloned(); let gb = df.group_by(names).unwrap(); Ok(finish(gb.get_groups())) } diff --git a/crates/polars-mem-engine/src/executors/group_by_rolling.rs b/crates/polars-mem-engine/src/executors/group_by_rolling.rs index 437976b103a3..810365b25bc6 100644 --- a/crates/polars-mem-engine/src/executors/group_by_rolling.rs +++ b/crates/polars-mem-engine/src/executors/group_by_rolling.rs @@ -26,7 +26,10 @@ unsafe fn update_keys(keys: &mut [Series], groups: &GroupsProxy) { }, GroupsProxy::Slice { groups, .. } => { for key in keys.iter_mut() { - let indices = groups.iter().map(|[first, _len]| *first).collect_ca(""); + let indices = groups + .iter() + .map(|[first, _len]| *first) + .collect_ca(PlSmallStr::EMPTY); *key = key.take_unchecked(&indices); } }, diff --git a/crates/polars-mem-engine/src/executors/projection_simple.rs b/crates/polars-mem-engine/src/executors/projection_simple.rs index 686321833bd2..c3102d3b7222 100644 --- a/crates/polars-mem-engine/src/executors/projection_simple.rs +++ b/crates/polars-mem-engine/src/executors/projection_simple.rs @@ -6,7 +6,7 @@ pub struct ProjectionSimple { } impl ProjectionSimple { - fn execute_impl(&mut self, df: DataFrame, columns: &[SmartString]) -> PolarsResult { + fn execute_impl(&mut self, df: DataFrame, columns: &[PlSmallStr]) -> PolarsResult { // No duplicate check as that an invariant of this node. 
df._select_impl_unchecked(columns.as_ref()) } @@ -15,10 +15,10 @@ impl ProjectionSimple { impl Executor for ProjectionSimple { fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { state.should_stop()?; - let columns = self.columns.iter_names().cloned().collect::>(); + let columns = self.columns.iter_names_cloned().collect::>(); let profile_name = if state.has_node_timer() { - let name = comma_delimited("simple-projection".to_string(), &columns); + let name = comma_delimited("simple-projection".to_string(), columns.as_slice()); Cow::Owned(name) } else { Cow::Borrowed("") @@ -26,9 +26,9 @@ impl Executor for ProjectionSimple { let df = self.input.execute(state)?; if state.has_node_timer() { - state.record(|| self.execute_impl(df, &columns), profile_name) + state.record(|| self.execute_impl(df, columns.as_slice()), profile_name) } else { - self.execute_impl(df, &columns) + self.execute_impl(df, columns.as_slice()) } } } diff --git a/crates/polars-mem-engine/src/executors/projection_utils.rs b/crates/polars-mem-engine/src/executors/projection_utils.rs index d59e99a94d4a..979c29321cb9 100644 --- a/crates/polars-mem-engine/src/executors/projection_utils.rs +++ b/crates/polars-mem-engine/src/executors/projection_utils.rs @@ -1,11 +1,11 @@ -use polars_utils::iter::EnumerateIdxTrait; +use polars_utils::itertools::Itertools; use super::*; pub(super) fn profile_name( s: &dyn PhysicalExpr, input_schema: &Schema, -) -> PolarsResult { +) -> PolarsResult { match s.to_field(input_schema) { Err(e) => Err(e), Ok(fld) => Ok(fld.name), diff --git a/crates/polars-mem-engine/src/executors/scan/csv.rs b/crates/polars-mem-engine/src/executors/scan/csv.rs index 936d602afc5f..0ebcb7632ae7 100644 --- a/crates/polars-mem-engine/src/executors/scan/csv.rs +++ b/crates/polars-mem-engine/src/executors/scan/csv.rs @@ -1,4 +1,3 @@ -use std::path::PathBuf; use std::sync::Arc; use polars_core::config; @@ -9,7 +8,7 @@ use polars_core::utils::{ use super::*; pub struct CsvExec { - pub paths: Arc>, + pub sources: ScanSources, pub file_info: FileInfo, pub options: CsvReadOptions, pub file_options: FileScanOptions, @@ -45,7 +44,7 @@ impl CsvExec { .with_row_index(None) .with_path::<&str>(None); - if self.paths.is_empty() { + if self.sources.is_empty() { let out = if let Some(schema) = options_base.schema { DataFrame::from_rows_and_schema(&[], schema.as_ref())? } else { @@ -56,56 +55,31 @@ impl CsvExec { let verbose = config::verbose(); let force_async = config::force_async(); - let run_async = force_async || is_cloud_url(self.paths.first().unwrap()); + let run_async = (self.sources.is_paths() && force_async) || self.sources.is_cloud_url(); - if force_async && verbose { + if self.sources.is_paths() && force_async && verbose { eprintln!("ASYNC READING FORCED"); } let finish_read = |i: usize, options: CsvReadOptions, predicate: Option>| { - let path = &self.paths[i]; - let mut df = if run_async { - #[cfg(feature = "cloud")] - { - let file = polars_io::file_cache::FILE_CACHE - .get_entry(path.to_str().unwrap()) - // Safety: This was initialized by schema inference. 
- .unwrap() - .try_open_assume_latest()?; - let owned = &mut vec![]; - let mmap = unsafe { memmap::Mmap::map(&file).unwrap() }; - - options - .into_reader_with_file_handle(std::io::Cursor::new( - maybe_decompress_bytes(mmap.as_ref(), owned)?, - )) - ._with_predicate(predicate.clone()) - .finish() - } - #[cfg(not(feature = "cloud"))] - { - panic!("required feature `cloud` is not enabled") - } - } else { - let file = polars_utils::open_file(path)?; - let mmap = unsafe { memmap::Mmap::map(&file).unwrap() }; - let owned = &mut vec![]; - - options - .into_reader_with_file_handle(std::io::Cursor::new(maybe_decompress_bytes( - mmap.as_ref(), - owned, - )?)) - ._with_predicate(predicate.clone()) - .finish() - }?; + let source = self.sources.at(i); + let owned = &mut vec![]; + + let memslice = source.to_memslice_async_latest(run_async)?; + + let reader = std::io::Cursor::new(maybe_decompress_bytes(&memslice, owned)?); + let mut df = options + .into_reader_with_file_handle(reader) + ._with_predicate(predicate.clone()) + .finish()?; if let Some(col) = &self.file_options.include_file_paths { - let path = path.to_str().unwrap(); + let name = source.to_include_path_name(); + unsafe { df.with_column_unchecked( - StringChunked::full(col, path, df.height()).into_series(), + StringChunked::full(col.clone(), name, df.height()).into_series(), ) }; } @@ -123,14 +97,14 @@ impl CsvExec { } let mut n_rows_read = 0usize; - let mut out = Vec::with_capacity(self.paths.len()); + let mut out = Vec::with_capacity(self.sources.len()); // If we have n_rows or row_index then we need to count how many rows we read, so we need // to delay applying the predicate. let predicate_during_read = predicate .clone() .filter(|_| n_rows.is_none() && self.file_options.row_index.is_none()); - for i in 0..self.paths.len() { + for i in 0..self.sources.len() { let opts = options_base .clone() .with_row_index(self.file_options.row_index.clone().map(|mut ri| { @@ -175,10 +149,10 @@ impl CsvExec { if n_rows.is_some() && n_rows_read == n_rows.unwrap() { if verbose { eprintln!( - "reached n_rows = {} at file {} / {}", + "reached n_rows = {} at source {} / {}", n_rows.unwrap(), 1 + i, - self.paths.len() + self.sources.len() ) } break; @@ -203,10 +177,10 @@ impl CsvExec { let dfs = POOL.install(|| { let step = std::cmp::min(POOL.current_num_threads(), 128); - (0..self.paths.len()) + (0..self.sources.len()) .step_by(step) .map(|start| { - (start..std::cmp::min(start.saturating_add(step), self.paths.len())) + (start..std::cmp::min(start.saturating_add(step), self.sources.len())) .into_par_iter() .map(|i| finish_read(i, options_base.clone(), predicate.clone())) .collect::>>() @@ -218,7 +192,7 @@ impl CsvExec { accumulate_dataframes_vertical(dfs.into_iter().flat_map(|dfs| dfs.into_iter()))?; if let Some(row_index) = self.file_options.row_index.clone() { - df.with_row_index_mut(row_index.name.as_ref(), Some(row_index.offset)); + df.with_row_index_mut(row_index.name.clone(), Some(row_index.offset)); } df @@ -235,7 +209,7 @@ impl CsvExec { impl Executor for CsvExec { fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { let profile_name = if state.has_node_timer() { - let mut ids = vec![self.paths[0].to_string_lossy().into()]; + let mut ids = vec![self.sources.id()]; if self.predicate.is_some() { ids.push("predicate".into()) } diff --git a/crates/polars-mem-engine/src/executors/scan/ipc.rs b/crates/polars-mem-engine/src/executors/scan/ipc.rs index 574b4c43252b..78b31f268756 100644 --- a/crates/polars-mem-engine/src/executors/scan/ipc.rs 
+++ b/crates/polars-mem-engine/src/executors/scan/ipc.rs @@ -1,19 +1,20 @@ -use std::path::PathBuf; - use hive::HivePartitions; use polars_core::config; use polars_core::utils::accumulate_dataframes_vertical; +use polars_error::feature_gated; use polars_io::cloud::CloudOptions; use polars_io::path_utils::is_cloud_url; use polars_io::predicates::apply_predicate; +use polars_utils::mmap::MemSlice; use rayon::prelude::*; use super::*; pub struct IpcExec { - pub(crate) paths: Arc>, + pub(crate) sources: ScanSources, pub(crate) file_info: FileInfo, pub(crate) predicate: Option>, + #[allow(dead_code)] pub(crate) options: IpcScanOptions, pub(crate) file_options: FileScanOptions, pub(crate) hive_parts: Option>>, @@ -22,23 +23,20 @@ pub struct IpcExec { impl IpcExec { fn read(&mut self) -> PolarsResult { - let is_cloud = self.paths.iter().any(is_cloud_url); + let is_cloud = match &self.sources { + ScanSources::Paths(paths) => paths.iter().any(is_cloud_url), + ScanSources::Files(_) | ScanSources::Buffers(_) => false, + }; let force_async = config::force_async(); - let mut out = if is_cloud || force_async { - #[cfg(not(feature = "cloud"))] - { - panic!("activate cloud feature") - } - - #[cfg(feature = "cloud")] - { + let mut out = if is_cloud || (self.sources.is_paths() && force_async) { + feature_gated!("cloud", { if force_async && config::verbose() { eprintln!("ASYNC READING FORCED"); } polars_io::pl_async::get_runtime().block_on_potential_spawn(self.read_async())? - } + }) } else { self.read_sync()? }; @@ -50,9 +48,9 @@ impl IpcExec { Ok(out) } - fn read_impl PolarsResult + Send + Sync>( + fn read_impl( &mut self, - path_idx_to_file: F, + idx_to_cached_file: impl Fn(usize) -> Option> + Send + Sync, ) -> PolarsResult { if config::verbose() { eprintln!("executing ipc read sync with row_index = {:?}, n_rows = {:?}, predicate = {:?} for paths {:?}", @@ -62,7 +60,7 @@ impl IpcExec { x.1 }).as_ref(), self.predicate.is_some(), - self.paths + self.sources, ); } @@ -73,26 +71,36 @@ impl IpcExec { self.file_options.row_index.is_some(), ); - let read_path = |path_index: usize, n_rows: Option| { - IpcReader::new(path_idx_to_file(path_index)?) + let read_path = |index: usize, n_rows: Option| { + let source = self.sources.at(index); + + let memslice = match source { + ScanSourceRef::Path(path) => { + let file = match idx_to_cached_file(index) { + None => std::fs::File::open(path)?, + Some(f) => f?, + }; + + MemSlice::from_file(&file)? 
+ }, + ScanSourceRef::File(file) => MemSlice::from_file(file)?, + ScanSourceRef::Buffer(buff) => MemSlice::from_bytes(buff.clone()), + }; + + IpcReader::new(std::io::Cursor::new(memslice)) .with_n_rows(n_rows) .with_row_index(self.file_options.row_index.clone()) .with_projection(projection.clone()) .with_hive_partition_columns( self.hive_parts .as_ref() - .map(|x| x[path_index].materialize_partition_columns()), + .map(|x| x[index].materialize_partition_columns()), ) - .with_include_file_path(self.file_options.include_file_paths.as_ref().map(|x| { - ( - x.clone(), - Arc::from(self.paths[path_index].to_str().unwrap().to_string()), - ) - })) - .memory_mapped( - self.options - .memory_map - .then(|| self.paths[path_index].clone()), + .with_include_file_path( + self.file_options + .include_file_paths + .as_ref() + .map(|x| (x.clone(), Arc::from(source.to_include_path_name()))), ) .finish() }; @@ -101,9 +109,9 @@ impl IpcExec { assert_eq!(x.0, 0); x.1 }) { - let mut out = Vec::with_capacity(self.paths.len()); + let mut out = Vec::with_capacity(self.sources.len()); - for i in 0..self.paths.len() { + for i in 0..self.sources.len() { let df = read_path(i, Some(n_rows))?; let df_height = df.height(); out.push(df); @@ -121,7 +129,7 @@ impl IpcExec { out } else { POOL.install(|| { - (0..self.paths.len()) + (0..self.sources.len()) .into_par_iter() .map(|i| read_path(i, None)) .collect::>>() @@ -157,8 +165,7 @@ impl IpcExec { } fn read_sync(&mut self) -> PolarsResult { - let paths = self.paths.clone(); - self.read_impl(move |i| std::fs::File::open(&paths[i]).map_err(Into::into)) + self.read_impl(|_| None) } #[cfg(feature = "cloud")] @@ -167,9 +174,11 @@ impl IpcExec { // concurrently. use polars_io::file_cache::init_entries_from_uri_list; + let paths = self.sources.into_paths().unwrap(); + tokio::task::block_in_place(|| { let cache_entries = init_entries_from_uri_list( - self.paths + paths .iter() .map(|x| Arc::from(x.to_str().unwrap())) .collect::>() @@ -177,7 +186,7 @@ impl IpcExec { self.cloud_options.as_ref(), )?; - self.read_impl(move |i| cache_entries[i].try_open_check_latest()) + self.read_impl(|i| Some(cache_entries[i].try_open_check_latest())) }) } } @@ -185,7 +194,7 @@ impl IpcExec { impl Executor for IpcExec { fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { let profile_name = if state.has_node_timer() { - let mut ids = vec![self.paths[0].to_string_lossy().into()]; + let mut ids = vec![self.sources.id()]; if self.predicate.is_some() { ids.push("predicate".into()) } diff --git a/crates/polars-mem-engine/src/executors/scan/mod.rs b/crates/polars-mem-engine/src/executors/scan/mod.rs index ddc1f1b4e6e1..1b46d40b9044 100644 --- a/crates/polars-mem-engine/src/executors/scan/mod.rs +++ b/crates/polars-mem-engine/src/executors/scan/mod.rs @@ -38,7 +38,7 @@ type Predicate = Option>; #[cfg(any(feature = "ipc", feature = "parquet"))] fn prepare_scan_args( predicate: Option>, - with_columns: &mut Option>, + with_columns: &mut Option>, schema: &mut SchemaRef, has_row_index: bool, hive_partitions: Option<&[Series]>, @@ -62,7 +62,7 @@ fn prepare_scan_args( pub struct DataFrameExec { pub(crate) df: Arc, pub(crate) filter: Option>, - pub(crate) projection: Option>, + pub(crate) projection: Option>, pub(crate) predicate_has_windows: bool, } @@ -74,7 +74,7 @@ impl Executor for DataFrameExec { // projection should be before selection as those are free // TODO: this is only the case if we don't create new columns if let Some(projection) = &self.projection { - df = 
df.select(projection.as_slice())?; + df = df.select(projection.iter().cloned())?; } if let Some(selection) = &self.filter { diff --git a/crates/polars-mem-engine/src/executors/scan/ndjson.rs b/crates/polars-mem-engine/src/executors/scan/ndjson.rs index 5e17a289eac7..a662760fd54b 100644 --- a/crates/polars-mem-engine/src/executors/scan/ndjson.rs +++ b/crates/polars-mem-engine/src/executors/scan/ndjson.rs @@ -1,12 +1,10 @@ -use std::path::PathBuf; - use polars_core::config; use polars_core::utils::accumulate_dataframes_vertical; use super::*; pub struct JsonExec { - paths: Arc>, + sources: ScanSources, options: NDJsonReadOptions, file_scan_options: FileScanOptions, file_info: FileInfo, @@ -15,14 +13,14 @@ pub struct JsonExec { impl JsonExec { pub fn new( - paths: Arc>, + sources: ScanSources, options: NDJsonReadOptions, file_scan_options: FileScanOptions, file_info: FileInfo, predicate: Option>, ) -> Self { Self { - paths, + sources, options, file_scan_options, file_info, @@ -41,9 +39,9 @@ impl JsonExec { let verbose = config::verbose(); let force_async = config::force_async(); - let run_async = force_async || is_cloud_url(self.paths.first().unwrap()); + let run_async = (self.sources.is_paths() && force_async) || self.sources.is_cloud_url(); - if force_async && verbose { + if self.sources.is_paths() && force_async && verbose { eprintln!("ASYNC READING FORCED"); } @@ -56,57 +54,38 @@ impl JsonExec { if n_rows == Some(0) { let mut df = DataFrame::empty_with_schema(schema); if let Some(col) = &self.file_scan_options.include_file_paths { - unsafe { df.with_column_unchecked(StringChunked::full_null(col, 0).into_series()) }; + unsafe { + df.with_column_unchecked(StringChunked::full_null(col.clone(), 0).into_series()) + }; } if let Some(row_index) = &self.file_scan_options.row_index { - df.with_row_index_mut(row_index.name.as_ref(), Some(row_index.offset)); + df.with_row_index_mut(row_index.name.clone(), Some(row_index.offset)); } return Ok(df); } let dfs = self - .paths + .sources .iter() - .map_while(|p| { + .map_while(|source| { if n_rows == Some(0) { return None; } - let file = if run_async { - #[cfg(feature = "cloud")] - { - match polars_io::file_cache::FILE_CACHE - .get_entry(p.to_str().unwrap()) - // Safety: This was initialized by schema inference. 
- .unwrap() - .try_open_assume_latest() - { - Ok(v) => v, - Err(e) => return Some(Err(e)), - } - } - #[cfg(not(feature = "cloud"))] - { - panic!("required feature `cloud` is not enabled") - } - } else { - match polars_utils::open_file(p.as_ref()) { - Ok(v) => v, - Err(e) => return Some(Err(e)), - } + let row_index = self.file_scan_options.row_index.as_mut(); + + let memslice = match source.to_memslice_async_latest(run_async) { + Ok(memslice) => memslice, + Err(err) => return Some(Err(err)), }; - let mmap = unsafe { memmap::Mmap::map(&file).unwrap() }; let owned = &mut vec![]; - let curs = - std::io::Cursor::new(match maybe_decompress_bytes(mmap.as_ref(), owned) { - Ok(v) => v, - Err(e) => return Some(Err(e)), - }); + let curs = std::io::Cursor::new(match maybe_decompress_bytes(&memslice, owned) { + Ok(v) => v, + Err(e) => return Some(Err(e)), + }); let reader = JsonLineReader::new(curs); - let row_index = self.file_scan_options.row_index.as_mut(); - let df = reader .with_schema(schema.clone()) .with_rechunk(self.file_scan_options.rechunk) @@ -129,10 +108,10 @@ impl JsonExec { } if let Some(col) = &self.file_scan_options.include_file_paths { - let path = p.to_str().unwrap(); + let name = source.to_include_path_name(); unsafe { df.with_column_unchecked( - StringChunked::full(col, path, df.height()).into_series(), + StringChunked::full(col.clone(), name, df.height()).into_series(), ) }; } @@ -148,7 +127,7 @@ impl JsonExec { impl Executor for JsonExec { fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { let profile_name = if state.has_node_timer() { - let ids = vec![self.paths[0].to_string_lossy().into()]; + let ids = vec![self.sources.id()]; let name = comma_delimited("ndjson".to_string(), &ids); Cow::Owned(name) } else { diff --git a/crates/polars-mem-engine/src/executors/scan/parquet.rs b/crates/polars-mem-engine/src/executors/scan/parquet.rs index a78dbf113151..a37fc7c42f33 100644 --- a/crates/polars-mem-engine/src/executors/scan/parquet.rs +++ b/crates/polars-mem-engine/src/executors/scan/parquet.rs @@ -1,20 +1,18 @@ -use std::path::PathBuf; - use hive::HivePartitions; use polars_core::config; #[cfg(feature = "cloud")] use polars_core::config::{get_file_prefetch_size, verbose}; use polars_core::utils::accumulate_dataframes_vertical; +use polars_error::feature_gated; use polars_io::cloud::CloudOptions; use polars_io::parquet::metadata::FileMetaDataRef; -use polars_io::path_utils::is_cloud_url; use polars_io::utils::slice::split_slice_at_file; use polars_io::RowIndex; use super::*; pub struct ParquetExec { - paths: Arc>, + sources: ScanSources, file_info: FileInfo, hive_parts: Option>>, predicate: Option>, @@ -29,7 +27,7 @@ pub struct ParquetExec { impl ParquetExec { #[allow(clippy::too_many_arguments)] pub(crate) fn new( - paths: Arc>, + sources: ScanSources, file_info: FileInfo, hive_parts: Option>>, predicate: Option>, @@ -39,7 +37,7 @@ impl ParquetExec { metadata: Option, ) -> Self { ParquetExec { - paths, + sources, file_info, hive_parts, predicate, @@ -52,7 +50,7 @@ impl ParquetExec { fn read_par(&mut self) -> PolarsResult> { let parallel = match self.options.parallel { - ParallelStrategy::Auto if self.paths.len() > POOL.current_num_threads() => { + ParallelStrategy::Auto if self.sources.len() > POOL.current_num_threads() => { ParallelStrategy::RowGroups }, identity => identity, @@ -62,7 +60,7 @@ impl ParquetExec { let step = std::cmp::min(POOL.current_num_threads(), 128); // Modified if we have a negative slice - let mut first_file = 0; + let mut first_source = 0; // 
(offset, end) let (slice_offset, slice_end) = if let Some(slice) = self.file_options.slice { @@ -75,15 +73,16 @@ impl ParquetExec { let mut cum_rows = 0; let chunk_size = 8; POOL.install(|| { - for path_indexes in (0..self.paths.len()) + for path_indexes in (0..self.sources.len()) .rev() .collect::>() .chunks(chunk_size) { let row_counts = path_indexes .into_par_iter() - .map(|i| { - ParquetReader::new(std::fs::File::open(&self.paths[*i])?).num_rows() + .map(|&i| { + let memslice = self.sources.at(i).to_memslice()?; + ParquetReader::new(std::io::Cursor::new(memslice)).num_rows() }) .collect::>>()?; @@ -91,12 +90,12 @@ impl ParquetExec { cum_rows += rc; if cum_rows >= slice_start_as_n_from_end { - first_file = *path_idx; + first_source = *path_idx; break; } } - if first_file > 0 { + if first_source > 0 { break; } } @@ -125,10 +124,8 @@ impl ParquetExec { let base_row_index = self.file_options.row_index.take(); // Limit no. of files at a time to prevent open file limits. - for i in (first_file..self.paths.len()).step_by(step) { - let end = std::cmp::min(i.saturating_add(step), self.paths.len()); - let paths = &self.paths[i..end]; - let hive_parts = self.hive_parts.as_ref().map(|x| &x[i..end]); + for i in (first_source..self.sources.len()).step_by(step) { + let end = std::cmp::min(i.saturating_add(step), self.sources.len()); if current_offset >= slice_end && !result.is_empty() { return Ok(result); @@ -137,11 +134,13 @@ impl ParquetExec { // First initialize the readers, predicates and metadata. // This will be used to determine the slices. That way we can actually read all the // files in parallel even if we add row index columns or slices. - let iter = (0..paths.len()).into_par_iter().map(|i| { - let path = &paths[i]; - let hive_partitions = hive_parts.map(|x| x[i].materialize_partition_columns()); + let iter = (i..end).into_par_iter().map(|i| { + let source = self.sources.at(i); + let hive_partitions = self + .hive_parts + .as_ref() + .map(|x| x[i].materialize_partition_columns()); - let file = std::fs::File::open(path)?; let (projection, predicate) = prepare_scan_args( self.predicate.clone(), &mut self.file_options.with_columns.clone(), @@ -150,7 +149,9 @@ impl ParquetExec { hive_partitions.as_deref(), ); - let mut reader = ParquetReader::new(file) + let memslice = source.to_memslice()?; + + let mut reader = ParquetReader::new(std::io::Cursor::new(memslice)) .read_parallel(parallel) .set_low_memory(self.options.low_memory) .use_statistics(self.options.use_statistics) @@ -160,7 +161,7 @@ impl ParquetExec { self.file_options .include_file_paths .as_ref() - .map(|x| (x.clone(), Arc::from(paths[i].to_str().unwrap()))), + .map(|x| (x.clone(), Arc::from(source.to_include_path_name()))), ); reader @@ -187,15 +188,14 @@ impl ParquetExec { readers_and_metadata .into_par_iter() .zip(row_statistics.into_par_iter()) - .enumerate() .map( - |(i, ((reader, _, predicate, projection), (cumulative_read, slice)))| { + |((reader, _, predicate, projection), (cumulative_read, slice))| { let row_index = base_row_index.as_ref().map(|rc| RowIndex { name: rc.name.clone(), offset: rc.offset + cumulative_read as IdxSize, }); - let mut df = reader + let df = reader .with_slice(Some(slice)) .with_row_index(row_index) .with_predicate(predicate.clone()) @@ -210,20 +210,6 @@ impl ParquetExec { )? 
.finish()?; - if let Some(col) = &self.file_options.include_file_paths { - let path = paths[i].to_str().unwrap(); - unsafe { - df.with_column_unchecked( - StringChunked::full( - col, - path, - std::cmp::max(df.height(), slice.1), - ) - .into_series(), - ) - }; - } - Ok(df) }, ) @@ -236,6 +222,7 @@ impl ParquetExec { result.extend_from_slice(&out) } } + Ok(result) } @@ -246,6 +233,7 @@ impl ParquetExec { use polars_io::utils::slice::split_slice_at_file; let verbose = verbose(); + let paths = self.sources.into_paths().unwrap(); let first_metadata = &self.metadata; let cloud_options = self.cloud_options.as_ref(); @@ -269,13 +257,13 @@ impl ParquetExec { let slice_start_as_n_from_end = -slice.0 as usize; let mut cum_rows = 0; - let paths = &self.paths; + let paths = &paths; let cloud_options = Arc::new(self.cloud_options.clone()); let paths = paths.clone(); let cloud_options = cloud_options.clone(); - let mut iter = stream::iter((0..self.paths.len()).rev().map(|i| { + let mut iter = stream::iter((0..paths.len()).rev().map(|i| { let paths = paths.clone(); let cloud_options = cloud_options.clone(); @@ -327,9 +315,9 @@ impl ParquetExec { let base_row_index = self.file_options.row_index.take(); let mut processed = 0; - for batch_start in (first_file_idx..self.paths.len()).step_by(batch_size) { - let end = std::cmp::min(batch_start.saturating_add(batch_size), self.paths.len()); - let paths = &self.paths[batch_start..end]; + for batch_start in (first_file_idx..paths.len()).step_by(batch_size) { + let end = std::cmp::min(batch_start.saturating_add(batch_size), paths.len()); + let paths = &paths[batch_start..end]; let hive_parts = self.hive_parts.as_ref().map(|x| &x[batch_start..end]); if current_offset >= slice_end && !result.is_empty() { @@ -340,7 +328,7 @@ impl ParquetExec { eprintln!( "querying metadata of {}/{} files...", processed, - self.paths.len() + paths.len() ); } @@ -386,7 +374,7 @@ impl ParquetExec { let include_file_paths = self.file_options.include_file_paths.as_ref(); if verbose { - eprintln!("reading of {}/{} file...", processed, self.paths.len()); + eprintln!("reading of {}/{} file...", processed, paths.len()); } let iter = readers_and_metadata @@ -462,23 +450,17 @@ impl ParquetExec { .and_then(|_| self.predicate.take()) .map(phys_expr_to_io_expr); - let is_cloud = is_cloud_url(self.paths.first().unwrap()); + let is_cloud = self.sources.is_cloud_url(); let force_async = config::force_async(); - let out = if is_cloud || force_async { - #[cfg(not(feature = "cloud"))] - { - panic!("activate cloud feature") - } - - #[cfg(feature = "cloud")] - { + let out = if is_cloud || (self.sources.is_paths() && force_async) { + feature_gated!("cloud", { if force_async && config::verbose() { eprintln!("ASYNC READING FORCED"); } polars_io::pl_async::get_runtime().block_on_potential_spawn(self.read_async())? - } + }) } else { self.read_par()? 
}; @@ -497,7 +479,7 @@ impl ParquetExec { impl Executor for ParquetExec { fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { let profile_name = if state.has_node_timer() { - let mut ids = vec![self.paths[0].to_string_lossy().into()]; + let mut ids = vec![self.sources.id()]; if self.predicate.is_some() { ids.push("predicate".into()) } diff --git a/crates/polars-mem-engine/src/executors/scan/python_scan.rs b/crates/polars-mem-engine/src/executors/scan/python_scan.rs index 1b44453b088d..270c52ea963c 100644 --- a/crates/polars-mem-engine/src/executors/scan/python_scan.rs +++ b/crates/polars-mem-engine/src/executors/scan/python_scan.rs @@ -68,7 +68,12 @@ impl Executor for PythonScanExec { self.options.python_source, PythonScanSource::Pyarrow | PythonScanSource::Cuda ) { - let args = (python_scan_function, with_columns, predicate, n_rows); + let args = ( + python_scan_function, + with_columns.map(|x| x.into_iter().map(|x| x.to_string()).collect::>()), + predicate, + n_rows, + ); callable.call1(args).map_err(to_compute_err) } else { // If there are filters, take smaller chunks to ensure we can keep memory @@ -80,7 +85,7 @@ impl Executor for PythonScanExec { }; let args = ( python_scan_function, - with_columns, + with_columns.map(|x| x.into_iter().map(|x| x.to_string()).collect::>()), predicate, n_rows, batch_size, diff --git a/crates/polars-mem-engine/src/executors/sort.rs b/crates/polars-mem-engine/src/executors/sort.rs index 820cdb65fdfd..23374abea7ac 100644 --- a/crates/polars-mem-engine/src/executors/sort.rs +++ b/crates/polars-mem-engine/src/executors/sort.rs @@ -1,3 +1,5 @@ +use polars_utils::format_pl_smallstr; + use super::*; pub(crate) struct SortExec { @@ -29,9 +31,14 @@ impl SortExec { // therefore we rename more complex expressions so that // polars core does not match these. if !matches!(e.as_expression(), Some(&Expr::Column(_))) { - s.rename(&format!("_POLARS_SORT_BY_{i}")); + s.rename(format_pl_smallstr!("_POLARS_SORT_BY_{i}")); } - polars_ensure!(s.len() == height, ShapeMismatch: "sort expressions must have same length as DataFrame, got DataFrame height: {} and Series length: {}", height, s.len()); + polars_ensure!( + s.len() == height, + ShapeMismatch: "sort expressions must have same \ + length as DataFrame, got DataFrame height: {} and Series length: {}", + height, s.len() + ); Ok(s) }) .collect::>>()?; diff --git a/crates/polars-mem-engine/src/executors/stack.rs b/crates/polars-mem-engine/src/executors/stack.rs index 212c098c10ca..3425c129fef2 100644 --- a/crates/polars-mem-engine/src/executors/stack.rs +++ b/crates/polars-mem-engine/src/executors/stack.rs @@ -33,22 +33,8 @@ impl StackExec { self.has_windows, self.options.run_parallel, )?; - if !self.options.should_broadcast { - debug_assert!( - res.iter() - .all(|column| column.name().starts_with("__POLARS_CSER_0x")), - "non-broadcasting hstack should only be used for CSE columns" - ); - // Safety: this case only appears as a result - // of CSE optimization, and the usage there - // produces new, unique column names. It is - // immediately followed by a projection which - // pulls out the possibly mismatching column - // lengths. - unsafe { df.get_columns_mut().extend(res) }; - } else { - df._add_columns(res, schema)?; - } + // We don't have to do a broadcast check as cse is not allowed to hit this. 
+ df._add_columns(res, schema)?; Ok(df) }); diff --git a/crates/polars-mem-engine/src/executors/udf.rs b/crates/polars-mem-engine/src/executors/udf.rs index 0916a5493ace..f5bea3a8cfee 100644 --- a/crates/polars-mem-engine/src/executors/udf.rs +++ b/crates/polars-mem-engine/src/executors/udf.rs @@ -2,7 +2,7 @@ use super::*; pub(crate) struct UdfExec { pub(crate) input: Box, - pub(crate) function: FunctionNode, + pub(crate) function: FunctionIR, } impl Executor for UdfExec { diff --git a/crates/polars-mem-engine/src/executors/unique.rs b/crates/polars-mem-engine/src/executors/unique.rs index 34be31149938..69c7b19c528a 100644 --- a/crates/polars-mem-engine/src/executors/unique.rs +++ b/crates/polars-mem-engine/src/executors/unique.rs @@ -2,7 +2,7 @@ use super::*; pub(crate) struct UniqueExec { pub(crate) input: Box, - pub(crate) options: DistinctOptions, + pub(crate) options: DistinctOptionsIR, } impl Executor for UniqueExec { @@ -15,7 +15,11 @@ impl Executor for UniqueExec { } } let df = self.input.execute(state)?; - let subset = self.options.subset.as_ref().map(|v| &***v); + let subset = self + .options + .subset + .as_ref() + .map(|v| v.iter().cloned().collect::>()); let keep = self.options.keep_strategy; state.record( @@ -24,10 +28,12 @@ impl Executor for UniqueExec { return Ok(df); } - match self.options.maintain_order { - true => df.unique_stable(subset, keep, self.options.slice), - false => df.unique(subset, keep, self.options.slice), - } + df.unique_impl( + self.options.maintain_order, + subset, + keep, + self.options.slice, + ) }, Cow::Borrowed("unique()"), ) diff --git a/crates/polars-mem-engine/src/planner/lp.rs b/crates/polars-mem-engine/src/planner/lp.rs index 31be11ae93b6..e1b53bea2151 100644 --- a/crates/polars-mem-engine/src/planner/lp.rs +++ b/crates/polars-mem-engine/src/planner/lp.rs @@ -244,7 +244,7 @@ fn create_physical_plan_impl( if streamable { // This can cause problems with string caches streamable = !input_schema - .iter_dtypes() + .iter_values() .any(|dt| dt.contains_categoricals()) || { #[cfg(feature = "dtype-categorical")] @@ -276,7 +276,7 @@ fn create_physical_plan_impl( }, #[allow(unused_variables)] Scan { - paths, + sources, file_info, hive_parts, output_schema, @@ -306,7 +306,7 @@ fn create_physical_plan_impl( match scan_type { #[cfg(feature = "csv")] FileScan::Csv { options, .. } => Ok(Box::new(executors::CsvExec { - paths, + sources, file_info, options, predicate, @@ -318,7 +318,7 @@ fn create_physical_plan_impl( cloud_options, metadata, } => Ok(Box::new(executors::IpcExec { - paths, + sources, file_info, predicate, options, @@ -332,7 +332,7 @@ fn create_physical_plan_impl( cloud_options, metadata, } => Ok(Box::new(executors::ParquetExec::new( - paths, + sources, file_info, hive_parts, predicate, @@ -343,7 +343,7 @@ fn create_physical_plan_impl( ))), #[cfg(feature = "json")] FileScan::NDJson { options, .. 
} => Ok(Box::new(executors::JsonExec::new( - paths, + sources, options, file_options, file_info, @@ -375,7 +375,8 @@ fn create_physical_plan_impl( state.expr_depth, ); - let streamable = all_streamable(&expr, expr_arena, Context::Default); + let streamable = + options.should_broadcast && all_streamable(&expr, expr_arena, Context::Default); let phys_expr = create_physical_expressions_from_irs( &expr, Context::Default, @@ -429,7 +430,7 @@ fn create_physical_plan_impl( .transpose()?; Ok(Box::new(executors::DataFrameExec { df, - projection: output_schema.map(|s| s.iter_names().cloned().collect()), + projection: output_schema.map(|s| s.iter_names_cloned().collect()), filter: selection, predicate_has_windows: state.has_windows, })) @@ -629,7 +630,8 @@ fn create_physical_plan_impl( let input_schema = lp_arena.get(input).schema(lp_arena).into_owned(); let input = create_physical_plan_impl(input, lp_arena, expr_arena, state)?; - let streamable = all_streamable(&exprs, expr_arena, Context::Default); + let streamable = + options.should_broadcast && all_streamable(&exprs, expr_arena, Context::Default); let mut state = ExpressionConversionState::new( POOL.current_num_threads() > exprs.len(), diff --git a/crates/polars-mem-engine/src/utils.rs b/crates/polars-mem-engine/src/utils.rs index cb04d599a7f0..91bd0e17902a 100644 --- a/crates/polars-mem-engine/src/utils.rs +++ b/crates/polars-mem-engine/src/utils.rs @@ -1,22 +1,28 @@ -use std::path::PathBuf; +use std::path::Path; pub(crate) use polars_plan::plans::ArenaLpIter; -use polars_plan::plans::IR; +use polars_plan::plans::{ScanSources, IR}; use polars_utils::aliases::PlHashSet; use polars_utils::arena::{Arena, Node}; /// Get a set of the data source paths in this LogicalPlan -pub(crate) fn agg_source_paths( +/// +/// # Notes +/// +/// - Scan sources with opened files or in-memory buffers are ignored. +pub(crate) fn agg_source_paths<'a>( root_lp: Node, - acc_paths: &mut PlHashSet, - lp_arena: &Arena, + acc_paths: &mut PlHashSet<&'a Path>, + lp_arena: &'a Arena, ) { - lp_arena.iter(root_lp).for_each(|(_, lp)| { - use IR::*; - if let Scan { paths, .. } = lp { - for path in paths.as_ref() { - acc_paths.insert(path.clone()); + for (_, lp) in lp_arena.iter(root_lp) { + if let IR::Scan { sources, .. 
} = lp { + match sources { + ScanSources::Paths(paths) => acc_paths.extend(paths.iter().map(|p| p.as_path())), + ScanSources::Buffers(_) | ScanSources::Files(_) => { + // Ignore + }, } } - }) + } } diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index 3bbdb10fcaf0..0782f188b1df 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -13,6 +13,7 @@ polars-compute = { workspace = true } polars-core = { workspace = true, features = ["algorithm_group_by", "zip_with"] } polars-error = { workspace = true } polars-json = { workspace = true, optional = true } +polars-schema = { workspace = true } polars-utils = { workspace = true } ahash = { workspace = true } @@ -35,7 +36,6 @@ rayon = { workspace = true } regex = { workspace = true } serde = { workspace = true, optional = true } serde_json = { workspace = true, optional = true } -smartstring = { workspace = true } unicode-reverse = { workspace = true, optional = true } [dependencies.jsonpath_lib] @@ -79,10 +79,11 @@ business = ["dtype-date", "chrono"] fused = [] cutqcut = ["dtype-categorical", "dtype-struct"] rle = ["dtype-struct"] -timezones = ["chrono-tz", "chrono"] +timezones = ["chrono", "chrono-tz", "polars-core/temporal", "polars-core/timezones", "polars-core/dtype-datetime"] random = ["rand", "rand_distr"] rank = ["rand"] find_many = ["aho-corasick"] +serde = ["dep:serde", "polars-core/serde", "polars-utils/serde", "polars-schema/serde"] # extra utilities for BinaryChunked binary_encoding = ["base64", "hex"] @@ -112,7 +113,7 @@ mode = [] search_sorted = [] merge_sorted = [] top_k = [] -pivot = ["polars-core/reinterpret"] +pivot = ["polars-core/reinterpret", "polars-core/dtype-struct"] cross_join = [] chunked_ids = [] asof_join = [] @@ -123,7 +124,7 @@ list_gather = [] list_sets = [] list_any_all = [] list_drop_nulls = [] -list_sample = [] +list_sample = ["polars-core/random"] extract_groups = ["dtype-struct", "polars-core/regex"] is_in = ["polars-core/reinterpret"] hist = ["dtype-categorical", "dtype-struct"] diff --git a/crates/polars-ops/src/chunked_array/array/any_all.rs b/crates/polars-ops/src/chunked_array/array/any_all.rs index 49bb3872d05d..8f9bd175c8ca 100644 --- a/crates/polars-ops/src/chunked_array/array/any_all.rs +++ b/crates/polars-ops/src/chunked_array/array/any_all.rs @@ -10,7 +10,7 @@ where { let values = arr.values(); - polars_ensure!(values.data_type() == &ArrowDataType::Boolean, ComputeError: "expected boolean elements in array"); + polars_ensure!(values.dtype() == &ArrowDataType::Boolean, ComputeError: "expected boolean elements in array"); let values = values.as_any().downcast_ref::().unwrap(); let validity = arr.validity().cloned(); @@ -43,12 +43,12 @@ pub(super) fn array_all(ca: &ArrayChunked) -> PolarsResult { let chunks = ca .downcast_iter() .map(|arr| array_all_any(arr, arrow::compute::boolean::all, true)); - Ok(BooleanChunked::try_from_chunk_iter(ca.name(), chunks)?.into_series()) + Ok(BooleanChunked::try_from_chunk_iter(ca.name().clone(), chunks)?.into_series()) } pub(super) fn array_any(ca: &ArrayChunked) -> PolarsResult { let chunks = ca .downcast_iter() .map(|arr| array_all_any(arr, arrow::compute::boolean::any, false)); - Ok(BooleanChunked::try_from_chunk_iter(ca.name(), chunks)?.into_series()) + Ok(BooleanChunked::try_from_chunk_iter(ca.name().clone(), chunks)?.into_series()) } diff --git a/crates/polars-ops/src/chunked_array/array/count.rs b/crates/polars-ops/src/chunked_array/array/count.rs index 528a9750306c..ef54e7b70591 100644 --- 
a/crates/polars-ops/src/chunked_array/array/count.rs +++ b/crates/polars-ops/src/chunked_array/array/count.rs @@ -8,7 +8,7 @@ use super::*; #[cfg(feature = "array_count")] pub fn array_count_matches(ca: &ArrayChunked, value: AnyValue) -> PolarsResult { - let value = Series::new("", [value]); + let value = Series::new(PlSmallStr::EMPTY, [value]); let ca = ca.apply_to_inner(&|s| { ChunkCompare::<&Series>::equal_missing(&s, &value).map(|ca| ca.into_series()) diff --git a/crates/polars-ops/src/chunked_array/array/dispersion.rs b/crates/polars-ops/src/chunked_array/array/dispersion.rs index 056b1b87d09a..17924d7c38bb 100644 --- a/crates/polars-ops/src/chunked_array/array/dispersion.rs +++ b/crates/polars-ops/src/chunked_array/array/dispersion.rs @@ -5,24 +5,24 @@ pub(super) fn median_with_nulls(ca: &ArrayChunked) -> PolarsResult { DataType::Float32 => { let out: Float32Chunked = ca .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().median().map(|v| v as f32))) - .with_name(ca.name()); + .with_name(ca.name().clone()); out.into_series() }, #[cfg(feature = "dtype-duration")] DataType::Duration(tu) => { let out: Int64Chunked = ca .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().median().map(|v| v as i64))) - .with_name(ca.name()); + .with_name(ca.name().clone()); out.into_duration(*tu).into_series() }, _ => { let out: Float64Chunked = ca .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().median())) - .with_name(ca.name()); + .with_name(ca.name().clone()); out.into_series() }, }; - out.rename(ca.name()); + out.rename(ca.name().clone()); Ok(out) } @@ -31,14 +31,14 @@ pub(super) fn std_with_nulls(ca: &ArrayChunked, ddof: u8) -> PolarsResult { let out: Float32Chunked = ca .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().std(ddof).map(|v| v as f32))) - .with_name(ca.name()); + .with_name(ca.name().clone()); out.into_series() }, #[cfg(feature = "dtype-duration")] DataType::Duration(tu) => { let out: Int64Chunked = ca .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().std(ddof).map(|v| v as i64))) - .with_name(ca.name()); + .with_name(ca.name().clone()); out.into_duration(*tu).into_series() }, _ => { @@ -50,7 +50,7 @@ pub(super) fn std_with_nulls(ca: &ArrayChunked, ddof: u8) -> PolarsResult PolarsResult { let out: Float32Chunked = ca .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof).map(|v| v as f32))) - .with_name(ca.name()); + .with_name(ca.name().clone()); out.into_series() }, #[cfg(feature = "dtype-duration")] DataType::Duration(TimeUnit::Milliseconds) => { let out: Int64Chunked = ca .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof).map(|v| v as i64))) - .with_name(ca.name()); + .with_name(ca.name().clone()); out.into_duration(TimeUnit::Milliseconds).into_series() }, #[cfg(feature = "dtype-duration")] @@ -80,16 +80,16 @@ pub(super) fn var_with_nulls(ca: &ArrayChunked, ddof: u8) -> PolarsResult { let out: Float64Chunked = ca .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof))) - .with_name(ca.name()); + .with_name(ca.name().clone()); out.into_series() }, }; - out.rename(ca.name()); + out.rename(ca.name().clone()); Ok(out) } diff --git a/crates/polars-ops/src/chunked_array/array/get.rs b/crates/polars-ops/src/chunked_array/array/get.rs index 46bf7232e390..1df931f165f7 100644 --- a/crates/polars-ops/src/chunked_array/array/get.rs +++ b/crates/polars-ops/src/chunked_array/array/get.rs @@ -11,7 +11,7 @@ fn array_get_literal(ca: &ArrayChunked, idx: i64, null_on_oob: bool) -> PolarsRe .downcast_iter() .map(|arr| 
sub_fixed_size_list_get_literal(arr, idx, null_on_oob)) .collect::>>()?; - Series::try_from((ca.name(), chunks)) + Series::try_from((ca.name().clone(), chunks)) .unwrap() .cast(ca.inner_dtype()) } @@ -31,7 +31,11 @@ pub fn array_get( if let Some(index) = index { array_get_literal(ca, index, null_on_oob) } else { - Ok(Series::full_null(ca.name(), ca.len(), ca.inner_dtype())) + Ok(Series::full_null( + ca.name().clone(), + ca.len(), + ca.inner_dtype(), + )) } }, len if len == ca.len() => { @@ -65,5 +69,5 @@ where .zip(rhs.downcast_iter()) .map(|(lhs_arr, rhs_arr)| op(lhs_arr, rhs_arr, null_on_oob)) .collect::>>()?; - Series::try_from((lhs.name(), chunks)) + Series::try_from((lhs.name().clone(), chunks)) } diff --git a/crates/polars-ops/src/chunked_array/array/join.rs b/crates/polars-ops/src/chunked_array/array/join.rs index 0ba4a517ca0f..426adb32826b 100644 --- a/crates/polars-ops/src/chunked_array/array/join.rs +++ b/crates/polars-ops/src/chunked_array/array/join.rs @@ -12,7 +12,7 @@ fn join_literal( }; let mut buf = String::with_capacity(128); - let mut builder = StringChunkedBuilder::new(ca.name(), ca.len()); + let mut builder = StringChunkedBuilder::new(ca.name().clone(), ca.len()); ca.for_each_amortized(|opt_s| { let opt_val = opt_s.and_then(|s| { @@ -45,7 +45,7 @@ fn join_many( ignore_nulls: bool, ) -> PolarsResult { let mut buf = String::new(); - let mut builder = StringChunkedBuilder::new(ca.name(), ca.len()); + let mut builder = StringChunkedBuilder::new(ca.name().clone(), ca.len()); { ca.amortized_iter() } .zip(separator) @@ -88,7 +88,7 @@ pub fn array_join( DataType::String => match separator.len() { 1 => match separator.get(0) { Some(separator) => join_literal(ca, separator, ignore_nulls), - _ => Ok(StringChunked::full_null(ca.name(), ca.len())), + _ => Ok(StringChunked::full_null(ca.name().clone(), ca.len())), }, _ => join_many(ca, separator, ignore_nulls), }, diff --git a/crates/polars-ops/src/chunked_array/array/min_max.rs b/crates/polars-ops/src/chunked_array/array/min_max.rs index bdeb76f250aa..a82de2436291 100644 --- a/crates/polars-ops/src/chunked_array/array/min_max.rs +++ b/crates/polars-ops/src/chunked_array/array/min_max.rs @@ -68,7 +68,7 @@ where } pub(super) fn array_dispatch( - name: &str, + name: PlSmallStr, values: &Series, width: usize, agg_type: AggType, diff --git a/crates/polars-ops/src/chunked_array/array/namespace.rs b/crates/polars-ops/src/chunked_array/array/namespace.rs index 1fa813be05a9..909ef5db8f6d 100644 --- a/crates/polars-ops/src/chunked_array/array/namespace.rs +++ b/crates/polars-ops/src/chunked_array/array/namespace.rs @@ -23,7 +23,7 @@ pub fn has_inner_nulls(ca: &ArrayChunked) -> bool { fn get_agg(ca: &ArrayChunked, agg_type: AggType) -> Series { let values = ca.get_inner(); let width = ca.width(); - min_max::array_dispatch(ca.name(), &values, width, agg_type) + min_max::array_dispatch(ca.name().clone(), &values, width, agg_type) } pub trait ArrayNameSpace: AsArray { @@ -149,7 +149,7 @@ pub trait ArrayNameSpace: AsArray { unsafe { ca.apply_amortized_same_type(|s| s.as_ref().shift(n)) } } else { ArrayChunked::full_null_with_dtype( - ca.name(), + ca.name().clone(), ca.len(), ca.inner_dtype(), ca.width(), diff --git a/crates/polars-ops/src/chunked_array/array/sum_mean.rs b/crates/polars-ops/src/chunked_array/array/sum_mean.rs index 60bd144317bc..27261a33eba0 100644 --- a/crates/polars-ops/src/chunked_array/array/sum_mean.rs +++ b/crates/polars-ops/src/chunked_array/array/sum_mean.rs @@ -53,7 +53,7 @@ pub(super) fn sum_array_numerical(ca: 
&ArrayChunked, inner_type: &DataType) -> S }) .collect::>(); - Series::try_from((ca.name(), chunks)).unwrap() + Series::try_from((ca.name().clone(), chunks)).unwrap() } pub(super) fn sum_with_nulls(ca: &ArrayChunked, inner_dtype: &DataType) -> PolarsResult { @@ -115,6 +115,6 @@ pub(super) fn sum_with_nulls(ca: &ArrayChunked, inner_dtype: &DataType) -> Polar }, } }; - out.rename(ca.name()); + out.rename(ca.name().clone()); Ok(out) } diff --git a/crates/polars-ops/src/chunked_array/array/to_struct.rs b/crates/polars-ops/src/chunked_array/array/to_struct.rs index 980135bcb169..b79a9ffcfe9f 100644 --- a/crates/polars-ops/src/chunked_array/array/to_struct.rs +++ b/crates/polars-ops/src/chunked_array/array/to_struct.rs @@ -1,14 +1,14 @@ use polars_core::export::rayon::prelude::*; use polars_core::POOL; -use polars_utils::format_smartstring; -use smartstring::alias::String as SmartString; +use polars_utils::format_pl_smallstr; +use polars_utils::pl_str::PlSmallStr; use super::*; -pub type ArrToStructNameGenerator = Arc SmartString + Send + Sync>; +pub type ArrToStructNameGenerator = Arc PlSmallStr + Send + Sync>; -pub fn arr_default_struct_name_gen(idx: usize) -> SmartString { - format_smartstring!("field_{idx}") +pub fn arr_default_struct_name_gen(idx: usize) -> PlSmallStr { + format_pl_smallstr!("field_{idx}") } pub trait ToStruct: AsArray { @@ -28,16 +28,19 @@ pub trait ToStruct: AsArray { (0..n_fields) .into_par_iter() .map(|i| { - ca.array_get(&Int64Chunked::from_slice("", &[i as i64]), true) - .map(|mut s| { - s.rename(&name_generator(i)); - s - }) + ca.array_get( + &Int64Chunked::from_slice(PlSmallStr::EMPTY, &[i as i64]), + true, + ) + .map(|mut s| { + s.rename(name_generator(i).clone()); + s + }) }) .collect::>>() })?; - StructChunked::from_series(ca.name(), &fields) + StructChunked::from_series(ca.name().clone(), &fields) } } diff --git a/crates/polars-ops/src/chunked_array/binary/namespace.rs b/crates/polars-ops/src/chunked_array/binary/namespace.rs index 6e4a29e86874..487f6a11f0df 100644 --- a/crates/polars-ops/src/chunked_array/binary/namespace.rs +++ b/crates/polars-ops/src/chunked_array/binary/namespace.rs @@ -24,7 +24,7 @@ pub trait BinaryNameSpaceImpl: AsBinary { match lit.len() { 1 => match lit.get(0) { Some(lit) => ca.contains(lit), - None => BooleanChunked::full_null(ca.name(), ca.len()), + None => BooleanChunked::full_null(ca.name().clone(), ca.len()), }, _ => broadcast_binary_elementwise_values(ca, lit, |src, lit| find(src, lit).is_some()), } @@ -35,7 +35,7 @@ pub trait BinaryNameSpaceImpl: AsBinary { let ca = self.as_binary(); let f = |s: &[u8]| s.ends_with(sub); let mut out: BooleanChunked = ca.into_iter().map(|opt_s| opt_s.map(f)).collect(); - out.rename(ca.name()); + out.rename(ca.name().clone()); out } @@ -44,7 +44,7 @@ pub trait BinaryNameSpaceImpl: AsBinary { let ca = self.as_binary(); let f = |s: &[u8]| s.starts_with(sub); let mut out: BooleanChunked = ca.into_iter().map(|opt_s| opt_s.map(f)).collect(); - out.rename(ca.name()); + out.rename(ca.name().clone()); out } @@ -53,7 +53,7 @@ pub trait BinaryNameSpaceImpl: AsBinary { match prefix.len() { 1 => match prefix.get(0) { Some(s) => self.starts_with(s), - None => BooleanChunked::full_null(ca.name(), ca.len()), + None => BooleanChunked::full_null(ca.name().clone(), ca.len()), }, _ => broadcast_binary_elementwise_values(ca, prefix, |s, sub| s.starts_with(sub)), } @@ -64,7 +64,7 @@ pub trait BinaryNameSpaceImpl: AsBinary { match suffix.len() { 1 => match suffix.get(0) { Some(s) => self.ends_with(s), - None => 
BooleanChunked::full_null(ca.name(), ca.len()), + None => BooleanChunked::full_null(ca.name().clone(), ca.len()), }, _ => broadcast_binary_elementwise_values(ca, suffix, |s, sub| s.ends_with(sub)), } diff --git a/crates/polars-ops/src/chunked_array/datetime/replace_time_zone.rs b/crates/polars-ops/src/chunked_array/datetime/replace_time_zone.rs index a84bef3d1534..1637dd392707 100644 --- a/crates/polars-ops/src/chunked_array/datetime/replace_time_zone.rs +++ b/crates/polars-ops/src/chunked_array/datetime/replace_time_zone.rs @@ -25,7 +25,7 @@ pub fn replace_time_zone( let mut out = datetime .0 .clone() - .into_datetime(datetime.time_unit(), time_zone.map(|x| x.to_string())); + .into_datetime(datetime.time_unit(), time_zone.map(PlSmallStr::from_str)); out.set_sorted_flag(datetime.is_sorted_flag()); return Ok(out); } @@ -64,7 +64,7 @@ pub fn replace_time_zone( ) }; - let mut out = out?.into_datetime(datetime.time_unit(), time_zone.map(|x| x.to_string())); + let mut out = out?.into_datetime(datetime.time_unit(), time_zone.map(PlSmallStr::from_str)); if from_time_zone == "UTC" && ambiguous.len() == 1 && ambiguous.get(0) == Some("raise") { // In general, the sortedness flag can't be preserved. // To be safe, we only do so in the simplest case when we know for sure that there is no "daylight savings weirdness" going on, i.e.: @@ -131,7 +131,7 @@ pub fn impl_replace_time_zone( }); element_iter.try_collect_arr() }); - ChunkedArray::try_from_chunk_iter(datetime.0.name(), iter) + ChunkedArray::try_from_chunk_iter(datetime.0.name().clone(), iter) }, _ => try_binary_elementwise(datetime, ambiguous, |timestamp_opt, ambiguous_opt| { match (timestamp_opt, ambiguous_opt) { diff --git a/crates/polars-ops/src/chunked_array/gather/chunked.rs b/crates/polars-ops/src/chunked_array/gather/chunked.rs index e22a9c935176..345f3689984c 100644 --- a/crates/polars-ops/src/chunked_array/gather/chunked.rs +++ b/crates/polars-ops/src/chunked_array/gather/chunked.rs @@ -140,7 +140,7 @@ impl TakeChunked for Series { out.into_decimal_unchecked(ca.precision(), ca.scale()) .into_series() }, - Null => Series::new_null(self.name(), by.len()), + Null => Series::new_null(self.name().clone(), by.len()), _ => unreachable!(), }; unsafe { out.cast_unchecked(self.dtype()).unwrap() } @@ -197,7 +197,7 @@ impl TakeChunked for Series { out.into_decimal_unchecked(ca.precision(), ca.scale()) .into_series() }, - Null => Series::new_null(self.name(), by.len()), + Null => Series::new_null(self.name().clone(), by.len()), _ => unreachable!(), }; unsafe { out.cast_unchecked(self.dtype()).unwrap() } @@ -225,7 +225,7 @@ where }); let arr = iter.collect_arr_trusted_with_dtype(arrow_dtype); - ChunkedArray::with_chunk(self.name(), arr) + ChunkedArray::with_chunk(self.name().clone(), arr) } else { let targets = self.downcast_iter().collect::>(); let iter = by.iter().map(|chunk_id| { @@ -238,7 +238,7 @@ where vals.get_unchecked(array_idx as usize) }); let arr = iter.collect_arr_trusted_with_dtype(arrow_dtype); - ChunkedArray::with_chunk(self.name(), arr) + ChunkedArray::with_chunk(self.name().clone(), arr) }; let sorted_flag = _update_gather_sorted_flag(self.is_sorted_flag(), sorted); out.set_sorted_flag(sorted_flag); @@ -264,7 +264,7 @@ where }) .collect_arr_trusted_with_dtype(arrow_dtype); - ChunkedArray::with_chunk(self.name(), arr) + ChunkedArray::with_chunk(self.name().clone(), arr) } else { let targets = self.downcast_iter().collect::>(); let arr = by @@ -280,7 +280,7 @@ where }) .collect_arr_trusted_with_dtype(arrow_dtype); - 
ChunkedArray::with_chunk(self.name(), arr) + ChunkedArray::with_chunk(self.name().clone(), arr) } } } @@ -291,7 +291,7 @@ unsafe fn take_unchecked_object(s: &Series, by: &[ChunkId], _sorted: IsSorted) - unreachable!() }; let reg = reg.as_ref().unwrap(); - let mut builder = (*reg.builder_constructor)(s.name(), by.len()); + let mut builder = (*reg.builder_constructor)(s.name().clone(), by.len()); by.iter().for_each(|chunk_id| { let (chunk_idx, array_idx) = chunk_id.extract(); @@ -307,7 +307,7 @@ unsafe fn take_opt_unchecked_object(s: &Series, by: &[NullableChunkId]) -> Serie unreachable!() }; let reg = reg.as_ref().unwrap(); - let mut builder = (*reg.builder_constructor)(s.name(), by.len()); + let mut builder = (*reg.builder_constructor)(s.name().clone(), by.len()); by.iter().for_each(|chunk_id| { if chunk_id.is_null() { @@ -409,7 +409,7 @@ unsafe fn take_unchecked_binview( ) .maybe_gc(); - let mut out = BinaryChunked::with_chunk(ca.name(), arr); + let mut out = BinaryChunked::with_chunk(ca.name().clone(), arr); let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), sorted); out.set_sorted_flag(sorted_flag); out @@ -485,7 +485,7 @@ unsafe fn take_unchecked_binview_opt(ca: &BinaryChunked, by: &[NullableChunkId]) ) .maybe_gc(); - BinaryChunked::with_chunk(ca.name(), arr) + BinaryChunked::with_chunk(ca.name().clone(), arr) } #[cfg(test)] @@ -497,15 +497,15 @@ mod test { unsafe { // # Series without nulls; let mut s_1 = Series::new( - "a", + "a".into(), &["1 loooooooooooong string", "2 loooooooooooong string"], ); let s_2 = Series::new( - "a", + "a".into(), &["11 loooooooooooong string", "22 loooooooooooong string"], ); let s_3 = Series::new( - "a", + "a".into(), &[ "111 loooooooooooong string", "222 loooooooooooong string", @@ -529,7 +529,7 @@ mod test { ]; let out = s_1.take_chunked_unchecked(&by, IsSorted::Not); - let idx = IdxCa::new("", [0, 1, 3, 2, 4, 5, 6]); + let idx = IdxCa::new("".into(), [0, 1, 3, 2, 4, 5, 6]); let expected = s_1.rechunk().take(&idx).unwrap(); assert!(out.equals(&expected)); @@ -542,16 +542,16 @@ mod test { ]; let out = s_1.take_opt_chunked_unchecked(&by); - let idx = IdxCa::new("", [None, Some(1), Some(3), Some(2)]); + let idx = IdxCa::new("".into(), [None, Some(1), Some(3), Some(2)]); let expected = s_1.rechunk().take(&idx).unwrap(); assert!(out.equals_missing(&expected)); // # Series with nulls; let mut s_1 = Series::new( - "a", + "a".into(), &["1 loooooooooooong string 1", "2 loooooooooooong string 2"], ); - let s_2 = Series::new("a", &[Some("11 loooooooooooong string 11"), None]); + let s_2 = Series::new("a".into(), &[Some("11 loooooooooooong string 11"), None]); s_1.append(&s_2).unwrap(); // ## Ids without nulls; @@ -563,7 +563,7 @@ mod test { ]; let out = s_1.take_chunked_unchecked(&by, IsSorted::Not); - let idx = IdxCa::new("", [0, 1, 3, 2]); + let idx = IdxCa::new("".into(), [0, 1, 3, 2]); let expected = s_1.rechunk().take(&idx).unwrap(); assert!(out.equals_missing(&expected)); @@ -576,7 +576,7 @@ mod test { ]; let out = s_1.take_opt_chunked_unchecked(&by); - let idx = IdxCa::new("", [None, Some(1), Some(3), Some(2)]); + let idx = IdxCa::new("".into(), [None, Some(1), Some(3), Some(2)]); let expected = s_1.rechunk().take(&idx).unwrap(); assert!(out.equals_missing(&expected)); } diff --git a/crates/polars-ops/src/chunked_array/gather_skip_nulls.rs b/crates/polars-ops/src/chunked_array/gather_skip_nulls.rs index ff52a6601589..5101d3668137 100644 --- a/crates/polars-ops/src/chunked_array/gather_skip_nulls.rs +++ 
b/crates/polars-ops/src/chunked_array/gather_skip_nulls.rs @@ -213,10 +213,10 @@ mod test { let idx_chunks: Vec<_> = (0..num_idx_chunks).map(|_| random_vec(&mut rng, 0..num_nonnull_elems as IdxSize, 0..200)).collect(); let null_idx_chunks: Vec<_> = idx_chunks.iter().map(|c| random_filter(&mut rng, c, 0.7..1.0)).collect(); - let nonnull_ca = UInt32Chunked::from_chunk_iter("", elem_chunks.iter().cloned().map(|v| v.into_iter().collect_arr())); - let ca = UInt32Chunked::from_chunk_iter("", null_elem_chunks.iter().cloned().map(|v| v.into_iter().collect_arr())); - let nonnull_idx_ca = IdxCa::from_chunk_iter("", idx_chunks.iter().cloned().map(|v| v.into_iter().collect_arr())); - let idx_ca = IdxCa::from_chunk_iter("", null_idx_chunks.iter().cloned().map(|v| v.into_iter().collect_arr())); + let nonnull_ca = UInt32Chunked::from_chunk_iter("".into(), elem_chunks.iter().cloned().map(|v| v.into_iter().collect_arr())); + let ca = UInt32Chunked::from_chunk_iter("".into(), null_elem_chunks.iter().cloned().map(|v| v.into_iter().collect_arr())); + let nonnull_idx_ca = IdxCa::from_chunk_iter("".into(), idx_chunks.iter().cloned().map(|v| v.into_iter().collect_arr())); + let idx_ca = IdxCa::from_chunk_iter("".into(), null_idx_chunks.iter().cloned().map(|v| v.into_iter().collect_arr())); gather_skip_nulls_check(&ca, &idx_ca); gather_skip_nulls_check(&ca, &nonnull_idx_ca); diff --git a/crates/polars-ops/src/chunked_array/hist.rs b/crates/polars-ops/src/chunked_array/hist.rs index d2a0acc76239..8d7781745531 100644 --- a/crates/polars-ops/src/chunked_array/hist.rs +++ b/crates/polars-ops/src/chunked_array/hist.rs @@ -3,7 +3,6 @@ use std::fmt::Write; use num_traits::ToPrimitive; use polars_core::prelude::*; use polars_core::with_match_physical_numeric_polars_type; -use polars_utils::float::IsFloat; use polars_utils::total_ord::ToTotalOrd; fn compute_hist( @@ -17,6 +16,7 @@ where T: PolarsNumericType, ChunkedArray: ChunkAgg, { + let mut lower_bound: f64; let (breaks, count) = if let Some(bins) = bins { let mut breaks = Vec::with_capacity(bins.len() + 1); breaks.extend_from_slice(bins); @@ -31,7 +31,7 @@ where // We start with the lower garbage bin. // (-inf, B0] - let mut lower_bound = f64::NEG_INFINITY; + lower_bound = f64::NEG_INFINITY; let mut upper_bound = *breaks_iter.next().unwrap(); for chunk in sorted.downcast_iter() { @@ -60,17 +60,17 @@ where while count.len() < breaks.len() { count.push(0) } + // Push lower bound to infinity + lower_bound = f64::NEG_INFINITY; (breaks, count) } else if ca.null_count() == ca.len() { + lower_bound = f64::NEG_INFINITY; let breaks: Vec = vec![f64::INFINITY]; let count: Vec = vec![0]; (breaks, count) } else { - let min = ChunkAgg::min(ca).unwrap().to_f64().unwrap(); - let max = ChunkAgg::max(ca).unwrap().to_f64().unwrap(); - - let start = min.floor() - 1.0; - let end = max.ceil() + 1.0; + let start = ChunkAgg::min(ca).unwrap().to_f64().unwrap(); + let end = ChunkAgg::max(ca).unwrap().to_f64().unwrap(); // If bin_count is omitted, default to the difference between start and stop (unit bins) let bin_count = if let Some(bin_count) = bin_count { @@ -79,37 +79,24 @@ where (end - start).round() as usize }; - // Calculate the breakpoints and make the array + // Calculate the breakpoints and make the array. The breakpoints form the RHS of the bins. 
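// Worked example of the binning arithmetic that follows (illustrative numbers, not taken
// from the patch): with start = 0.0, end = 10.0 and bin_count = 5, interval = 2.0 and the
// breaks come out as [2.0, 4.0, 6.0, 8.0, 10.0]. A value of 7.0 maps to bin
// (((7.0 - 0.0) / 2.0).ceil() - 1.0) = 3, i.e. the (6.0, 8.0] bucket, while the minimum 0.0
// lands in bin 0 (the negative intermediate saturates to 0 on the `as usize` cast, and the
// margin-extended lower bound keeps it inside the displayed category).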
let interval = (end - start) / (bin_count as f64); - - let breaks_iter = (0..(bin_count)).map(|b| start + (b as f64) * interval); - + let breaks_iter = (1..(bin_count)).map(|b| start + (b as f64) * interval); let mut breaks = Vec::with_capacity(breaks_iter.size_hint().0 + 1); breaks.extend(breaks_iter); - breaks.push(f64::INFINITY); - let mut count: Vec = vec![0; breaks.len()]; - let end_idx = count.len() - 1; + // Extend the left-most edge by 0.1% of the total range to include the minimum value. + let margin = (end - start) * 0.001; + lower_bound = start - margin; + breaks.push(end); - // start is the closed rhs of the interval, so we subtract the bucket width - let start_range = start - interval; + let mut count: Vec = vec![0; bin_count]; + let max_bin = breaks.len() - 1; for chunk in ca.downcast_iter() { for item in chunk.non_null_values_iter() { - let item = item.to_f64().unwrap() - start_range; - - // This is needed for numeric stability. - // Only for integers. - // we can fall directly on a boundary with an integer. - let item = item / interval; - let item = if !T::Native::is_float() && (item.round() - item).abs() < 0.0000001 { - item.round() - 1.0 - } else { - item.ceil() - 1.0 - }; - - let idx = item as usize; - let idx = std::cmp::min(idx, end_idx); - count[idx] += 1; + let item = item.to_f64().unwrap(); + let bin = ((((item - start) / interval).ceil() - 1.0) as usize).min(max_bin); + count[bin] += 1; } } (breaks, count) @@ -117,8 +104,9 @@ where let mut fields = Vec::with_capacity(3); if include_category { // Use AnyValue for formatting. - let mut lower = AnyValue::Float64(f64::NEG_INFINITY); - let mut categories = StringChunkedBuilder::new("category", breaks.len()); + let mut lower = AnyValue::Float64(lower_bound); + let mut categories = + StringChunkedBuilder::new(PlSmallStr::from_static("category"), breaks.len()); let mut buf = String::new(); for br in &breaks { @@ -135,17 +123,20 @@ where fields.push(categories); }; if include_breakpoint { - fields.insert(0, Series::new("breakpoint", breaks)) + fields.insert( + 0, + Series::new(PlSmallStr::from_static("breakpoint"), breaks), + ) } - let count = Series::new("count", count); + let count = Series::new(PlSmallStr::from_static("count"), count); fields.push(count); if fields.len() == 1 { let out = fields.pop().unwrap(); - out.with_name(ca.name()) + out.with_name(ca.name().clone()) } else { - StructChunked::from_series(ca.name(), &fields) + StructChunked::from_series(ca.name().clone(), &fields) .unwrap() .into_series() } diff --git a/crates/polars-ops/src/chunked_array/list/any_all.rs b/crates/polars-ops/src/chunked_array/list/any_all.rs index 1364a872b133..a8727bb3082a 100644 --- a/crates/polars-ops/src/chunked_array/list/any_all.rs +++ b/crates/polars-ops/src/chunked_array/list/any_all.rs @@ -10,7 +10,7 @@ where let offsets = arr.offsets().as_slice(); let values = arr.values(); - polars_ensure!(values.data_type() == &ArrowDataType::Boolean, ComputeError: "expected boolean elements in list"); + polars_ensure!(values.dtype() == &ArrowDataType::Boolean, ComputeError: "expected boolean elements in list"); let values = values.as_any().downcast_ref::().unwrap(); let validity = arr.validity().cloned(); @@ -41,12 +41,12 @@ pub(super) fn list_all(ca: &ListChunked) -> PolarsResult { let chunks = ca .downcast_iter() .map(|arr| list_all_any(arr, arrow::compute::boolean::all, true)); - Ok(BooleanChunked::try_from_chunk_iter(ca.name(), chunks)?.into_series()) + Ok(BooleanChunked::try_from_chunk_iter(ca.name().clone(), chunks)?.into_series()) } 
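// Note on the recurring `ca.name()` -> `ca.name().clone()` edits in this and the surrounding
// hunks: column names are now owned `PlSmallStr` handles rather than `&str`, so constructors,
// `rename` and `with_name` take the name by value. A minimal sketch of a new-style call site,
// assuming the PlSmallStr API used elsewhere in this patch (the values here are made up):
// let mut s = Series::new(PlSmallStr::from_static("count"), vec![0u32; 3]);
// s.rename(PlSmallStr::from_static("n")); // or reuse an existing name: s.rename(other.name().clone());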
pub(super) fn list_any(ca: &ListChunked) -> PolarsResult { let chunks = ca .downcast_iter() .map(|arr| list_all_any(arr, arrow::compute::boolean::any, false)); - Ok(BooleanChunked::try_from_chunk_iter(ca.name(), chunks)?.into_series()) + Ok(BooleanChunked::try_from_chunk_iter(ca.name().clone(), chunks)?.into_series()) } diff --git a/crates/polars-ops/src/chunked_array/list/count.rs b/crates/polars-ops/src/chunked_array/list/count.rs index 4c562f1d1072..e54c603f3a25 100644 --- a/crates/polars-ops/src/chunked_array/list/count.rs +++ b/crates/polars-ops/src/chunked_array/list/count.rs @@ -42,7 +42,7 @@ fn count_bits_set_by_offsets(values: &Bitmap, offset: &[i64]) -> Vec { #[cfg(feature = "list_count")] pub fn list_count_matches(ca: &ListChunked, value: AnyValue) -> PolarsResult { - let value = Series::new("", [value]); + let value = Series::new(PlSmallStr::EMPTY, [value]); let ca = ca.apply_to_inner(&|s| { ChunkCompare::<&Series>::equal_missing(&s, &value).map(|ca| ca.into_series()) @@ -59,5 +59,5 @@ pub(super) fn count_boolean_bits(ca: &ListChunked) -> IdxCa { let out = count_bits_set_by_offsets(mask.values(), arr.offsets().as_slice()); IdxArr::from_data_default(out.into(), arr.validity().cloned()) }); - IdxCa::from_chunk_iter(ca.name(), chunks) + IdxCa::from_chunk_iter(ca.name().clone(), chunks) } diff --git a/crates/polars-ops/src/chunked_array/list/dispersion.rs b/crates/polars-ops/src/chunked_array/list/dispersion.rs index 76c4075f265b..2796ebb1de9e 100644 --- a/crates/polars-ops/src/chunked_array/list/dispersion.rs +++ b/crates/polars-ops/src/chunked_array/list/dispersion.rs @@ -5,20 +5,20 @@ pub(super) fn median_with_nulls(ca: &ListChunked) -> Series { DataType::Float32 => { let out: Float32Chunked = ca .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().median().map(|v| v as f32))) - .with_name(ca.name()); + .with_name(ca.name().clone()); out.into_series() }, #[cfg(feature = "dtype-duration")] DataType::Duration(tu) => { let out: Int64Chunked = ca .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().median().map(|v| v as i64))) - .with_name(ca.name()); + .with_name(ca.name().clone()); out.into_duration(*tu).into_series() }, _ => { let out: Float64Chunked = ca .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().median())) - .with_name(ca.name()); + .with_name(ca.name().clone()); out.into_series() }, }; @@ -29,20 +29,20 @@ pub(super) fn std_with_nulls(ca: &ListChunked, ddof: u8) -> Series { DataType::Float32 => { let out: Float32Chunked = ca .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().std(ddof).map(|v| v as f32))) - .with_name(ca.name()); + .with_name(ca.name().clone()); out.into_series() }, #[cfg(feature = "dtype-duration")] DataType::Duration(tu) => { let out: Int64Chunked = ca .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().std(ddof).map(|v| v as i64))) - .with_name(ca.name()); + .with_name(ca.name().clone()); out.into_duration(*tu).into_series() }, _ => { let out: Float64Chunked = ca .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().std(ddof))) - .with_name(ca.name()); + .with_name(ca.name().clone()); out.into_series() }, }; @@ -53,14 +53,14 @@ pub(super) fn var_with_nulls(ca: &ListChunked, ddof: u8) -> Series { DataType::Float32 => { let out: Float32Chunked = ca .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof).map(|v| v as f32))) - .with_name(ca.name()); + .with_name(ca.name().clone()); out.into_series() }, #[cfg(feature = "dtype-duration")] DataType::Duration(TimeUnit::Milliseconds) => { let out: Int64Chunked = ca 
.apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof).map(|v| v as i64))) - .with_name(ca.name()); + .with_name(ca.name().clone()); out.into_duration(TimeUnit::Milliseconds).into_series() }, #[cfg(feature = "dtype-duration")] @@ -73,13 +73,13 @@ pub(super) fn var_with_nulls(ca: &ListChunked, ddof: u8) -> Series { .list() .unwrap() .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof).map(|v| v as i64))) - .with_name(ca.name()); + .with_name(ca.name().clone()); out.into_duration(TimeUnit::Milliseconds).into_series() }, _ => { let out: Float64Chunked = ca .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().var(ddof))) - .with_name(ca.name()); + .with_name(ca.name().clone()); out.into_series() }, }; diff --git a/crates/polars-ops/src/chunked_array/list/hash.rs b/crates/polars-ops/src/chunked_array/list/hash.rs index 67cb61a51273..0c567c729041 100644 --- a/crates/polars-ops/src/chunked_array/list/hash.rs +++ b/crates/polars-ops/src/chunked_array/list/hash.rs @@ -9,7 +9,7 @@ use polars_utils::total_ord::{ToTotalOrd, TotalHash}; use super::*; -fn hash_agg(ca: &ChunkedArray, random_state: &ahash::RandomState) -> u64 +fn hash_agg(ca: &ChunkedArray, random_state: &PlRandomState) -> u64 where T: PolarsNumericType, T::Native: TotalHash + ToTotalOrd, @@ -44,7 +44,7 @@ where hash_agg } -pub(crate) fn hash(ca: &mut ListChunked, build_hasher: ahash::RandomState) -> UInt64Chunked { +pub(crate) fn hash(ca: &mut ListChunked, build_hasher: PlRandomState) -> UInt64Chunked { if !ca.inner_dtype().to_physical().is_numeric() { panic!( "Hashing a list with a non-numeric inner type not supported. Got dtype: {:?}", @@ -80,6 +80,6 @@ pub(crate) fn hash(ca: &mut ListChunked, build_hasher: ahash::RandomState) -> UI }); let mut out = out.into_inner(); - out.rename(ca.name()); + out.rename(ca.name().clone()); out } diff --git a/crates/polars-ops/src/chunked_array/list/min_max.rs b/crates/polars-ops/src/chunked_array/list/min_max.rs index 10f275f32183..8d3a4d1d4197 100644 --- a/crates/polars-ops/src/chunked_array/list/min_max.rs +++ b/crates/polars-ops/src/chunked_array/list/min_max.rs @@ -66,7 +66,7 @@ fn min_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Series { }) .collect::>(); - Series::try_from((ca.name(), chunks)).unwrap() + Series::try_from((ca.name().clone(), chunks)).unwrap() } pub(super) fn list_min_function(ca: &ListChunked) -> PolarsResult { @@ -92,7 +92,7 @@ pub(super) fn list_min_function(ca: &ListChunked) -> PolarsResult { .try_apply_amortized(|s| { let s = s.as_ref(); let sc = s.min_reduce()?; - Ok(sc.into_series(s.name())) + Ok(sc.into_series(s.name().clone())) })? .explode() .unwrap() @@ -175,7 +175,7 @@ fn max_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Series { }) .collect::>(); - Series::try_from((ca.name(), chunks)).unwrap() + Series::try_from((ca.name().clone(), chunks)).unwrap() } pub(super) fn list_max_function(ca: &ListChunked) -> PolarsResult { @@ -202,7 +202,7 @@ pub(super) fn list_max_function(ca: &ListChunked) -> PolarsResult { .try_apply_amortized(|s| { let s = s.as_ref(); let sc = s.max_reduce()?; - Ok(sc.into_series(s.name())) + Ok(sc.into_series(s.name().clone())) })? 
.explode() .unwrap() diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index 0306375af35f..0c7a0975488c 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -87,7 +87,7 @@ pub trait ListNameSpaceImpl: AsList { DataType::String => match separator.len() { 1 => match separator.get(0) { Some(separator) => self.join_literal(separator, ignore_nulls), - _ => Ok(StringChunked::full_null(ca.name(), ca.len())), + _ => Ok(StringChunked::full_null(ca.name().clone(), ca.len())), }, _ => self.join_many(separator, ignore_nulls), }, @@ -99,7 +99,7 @@ pub trait ListNameSpaceImpl: AsList { let ca = self.as_list(); // used to amortize heap allocs let mut buf = String::with_capacity(128); - let mut builder = StringChunkedBuilder::new(ca.name(), ca.len()); + let mut builder = StringChunkedBuilder::new(ca.name().clone(), ca.len()); ca.for_each_amortized(|opt_s| { let opt_val = opt_s.and_then(|s| { @@ -135,7 +135,7 @@ pub trait ListNameSpaceImpl: AsList { let ca = self.as_list(); // used to amortize heap allocs let mut buf = String::with_capacity(128); - let mut builder = StringChunkedBuilder::new(ca.name(), ca.len()); + let mut builder = StringChunkedBuilder::new(ca.name().clone(), ca.len()); { ca.amortized_iter() .zip(separator) @@ -303,7 +303,7 @@ pub trait ListNameSpaceImpl: AsList { if let Some(periods) = periods.get(0) { ca.apply_amortized(|s| s.as_ref().shift(periods)) } else { - ListChunked::full_null_with_dtype(ca.name(), ca.len(), ca.inner_dtype()) + ListChunked::full_null_with_dtype(ca.name().clone(), ca.len(), ca.inner_dtype()) } }, _ => ca.zip_and_apply_amortized(periods, |opt_s, opt_periods| { @@ -333,7 +333,7 @@ pub trait ListNameSpaceImpl: AsList { last = *o; } }); - IdxCa::from_vec(ca.name(), lengths) + IdxCa::from_vec(ca.name().clone(), lengths) } /// Get the value by index in the sublists. @@ -352,7 +352,7 @@ pub trait ListNameSpaceImpl: AsList { .collect::>(); // SAFETY: every element in list has dtype equal to its inner type unsafe { - Series::try_from((ca.name(), chunks)) + Series::try_from((ca.name().clone(), chunks)) .unwrap() .cast_unchecked(ca.inner_dtype()) } @@ -366,7 +366,7 @@ pub trait ListNameSpaceImpl: AsList { (Some(n), Some(offset)) => list_ca .apply_amortized(|s| s.as_ref().gather_every(n as usize, offset as usize)), _ => ListChunked::full_null_with_dtype( - list_ca.name(), + list_ca.name().clone(), list_ca.len(), list_ca.inner_dtype(), ), @@ -383,7 +383,7 @@ pub trait ListNameSpaceImpl: AsList { }) } else { ListChunked::full_null_with_dtype( - list_ca.name(), + list_ca.name().clone(), list_ca.len(), list_ca.inner_dtype(), ) @@ -399,7 +399,7 @@ pub trait ListNameSpaceImpl: AsList { }) } else { ListChunked::full_null_with_dtype( - list_ca.name(), + list_ca.name().clone(), list_ca.len(), list_ca.inner_dtype(), ) @@ -439,7 +439,7 @@ pub trait ListNameSpaceImpl: AsList { }) .collect::>() .map(|mut ca| { - ca.rename(list_ca.name()); + ca.rename(list_ca.name().clone()); ca.into_series() }) } @@ -447,7 +447,7 @@ pub trait ListNameSpaceImpl: AsList { use DataType::*; match idx.dtype() { - List(_) => { + List(boxed_dt) if boxed_dt.is_integer() => { let idx_ca = idx.list().unwrap(); let mut out = { list_ca @@ -466,7 +466,7 @@ pub trait ListNameSpaceImpl: AsList { }) .collect::>()? 
}; - out.rename(list_ca.name()); + out.rename(list_ca.name().clone()); Ok(out.into_series()) }, @@ -486,7 +486,7 @@ pub trait ListNameSpaceImpl: AsList { }) .collect::>()? }; - out.rename(list_ca.name()); + out.rename(list_ca.name().clone()); Ok(out.into_series()) } } else { @@ -526,7 +526,7 @@ pub trait ListNameSpaceImpl: AsList { }) } else { Ok(ListChunked::full_null_with_dtype( - ca.name(), + ca.name().clone(), ca.len(), ca.inner_dtype(), )) @@ -565,7 +565,7 @@ pub trait ListNameSpaceImpl: AsList { }) } else { Ok(ListChunked::full_null_with_dtype( - ca.name(), + ca.name().clone(), ca.len(), ca.inner_dtype(), )) @@ -635,7 +635,7 @@ pub trait ListNameSpaceImpl: AsList { // there was a None, so all values will be None if to_append.len() != other_len { return Ok(ListChunked::full_null_with_dtype( - ca.name(), + ca.name().clone(), length, &inner_super_type, )); @@ -650,7 +650,7 @@ pub trait ListNameSpaceImpl: AsList { &inner_super_type, ca.get_values_size() + vals_size_other + 1, length, - ca.name(), + ca.name().clone(), )?; ca.into_iter().for_each(|opt_s| { let opt_s = opt_s.map(|mut s| { @@ -687,7 +687,7 @@ pub trait ListNameSpaceImpl: AsList { &inner_super_type, ca.get_values_size() + vals_size_other + 1, length, - ca.name(), + ca.name().clone(), )?; for _ in 0..ca.len() { diff --git a/crates/polars-ops/src/chunked_array/list/sets.rs b/crates/polars-ops/src/chunked_array/list/sets.rs index 4a6f1f0466b4..e105d96b737a 100644 --- a/crates/polars-ops/src/chunked_array/list/sets.rs +++ b/crates/polars-ops/src/chunked_array/list/sets.rs @@ -251,7 +251,7 @@ where offsets.push(offset as i64); } let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) }; - let dtype = ListArray::::default_datatype(values_out.data_type().clone()); + let dtype = ListArray::::default_datatype(values_out.dtype().clone()); let values: PrimitiveArray = values_out.into(); Ok(ListArray::new(dtype, offsets, values.boxed(), validity)) @@ -346,10 +346,10 @@ fn binary( if as_utf8 { let values = unsafe { values.to_utf8view_unchecked() }; - let dtype = ListArray::::default_datatype(values.data_type().clone()); + let dtype = ListArray::::default_datatype(values.dtype().clone()); Ok(ListArray::new(dtype, offsets, values.boxed(), validity)) } else { - let dtype = ListArray::::default_datatype(values.data_type().clone()); + let dtype = ListArray::::default_datatype(values.dtype().clone()); Ok(ListArray::new(dtype, offsets, values.boxed(), validity)) } } @@ -364,9 +364,9 @@ fn array_set_operation( let values_a = a.values(); let values_b = b.values(); - assert_eq!(values_a.data_type(), values_b.data_type()); + assert_eq!(values_a.dtype(), values_b.dtype()); - let dtype = values_b.data_type(); + let dtype = values_b.dtype(); let validity = combine_validities_and(a.validity(), b.validity()); match dtype { diff --git a/crates/polars-ops/src/chunked_array/list/sum_mean.rs b/crates/polars-ops/src/chunked_array/list/sum_mean.rs index edbe584c436a..87dff648b9b2 100644 --- a/crates/polars-ops/src/chunked_array/list/sum_mean.rs +++ b/crates/polars-ops/src/chunked_array/list/sum_mean.rs @@ -62,7 +62,7 @@ pub(super) fn sum_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Ser }) .collect::>(); - Series::try_from((ca.name(), chunks)).unwrap() + Series::try_from((ca.name().clone(), chunks)).unwrap() } pub(super) fn sum_with_nulls(ca: &ListChunked, inner_dtype: &DataType) -> PolarsResult { @@ -106,12 +106,16 @@ pub(super) fn sum_with_nulls(ca: &ListChunked, inner_dtype: &DataType) -> Polars }, // slowest sum_as_series path _ => 
ca - .try_apply_amortized(|s| s.as_ref().sum_reduce().map(|sc| sc.into_series("")))? + .try_apply_amortized(|s| { + s.as_ref() + .sum_reduce() + .map(|sc| sc.into_series(PlSmallStr::EMPTY)) + })? .explode() .unwrap() .into_series(), }; - out.rename(ca.name()); + out.rename(ca.name().clone()); Ok(out) } @@ -167,7 +171,7 @@ pub(super) fn mean_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Se }) .collect::>(); - Series::try_from((ca.name(), chunks)).unwrap() + Series::try_from((ca.name().clone(), chunks)).unwrap() } pub(super) fn mean_with_nulls(ca: &ListChunked) -> Series { @@ -175,13 +179,13 @@ pub(super) fn mean_with_nulls(ca: &ListChunked) -> Series { DataType::Float32 => { let out: Float32Chunked = ca .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().mean().map(|v| v as f32))) - .with_name(ca.name()); + .with_name(ca.name().clone()); out.into_series() }, _ => { let out: Float64Chunked = ca .apply_amortized_generic(|s| s.and_then(|s| s.as_ref().mean())) - .with_name(ca.name()); + .with_name(ca.name().clone()); out.into_series() }, }; diff --git a/crates/polars-ops/src/chunked_array/list/to_struct.rs b/crates/polars-ops/src/chunked_array/list/to_struct.rs index 2f887c69e020..73798163ed48 100644 --- a/crates/polars-ops/src/chunked_array/list/to_struct.rs +++ b/crates/polars-ops/src/chunked_array/list/to_struct.rs @@ -1,7 +1,7 @@ use polars_core::export::rayon::prelude::*; use polars_core::POOL; -use polars_utils::format_smartstring; -use smartstring::alias::String as SmartString; +use polars_utils::format_pl_smallstr; +use polars_utils::pl_str::PlSmallStr; use super::*; @@ -48,10 +48,10 @@ fn det_n_fields(ca: &ListChunked, n_fields: ListToStructWidthStrategy) -> usize } } -pub type NameGenerator = Arc SmartString + Send + Sync>; +pub type NameGenerator = Arc PlSmallStr + Send + Sync>; -pub fn _default_struct_name_gen(idx: usize) -> SmartString { - format_smartstring!("field_{idx}") +pub fn _default_struct_name_gen(idx: usize) -> PlSmallStr { + format_pl_smallstr!("field_{idx}") } pub trait ToStruct: AsList { @@ -73,14 +73,14 @@ pub trait ToStruct: AsList { .into_par_iter() .map(|i| { ca.lst_get(i as i64, true).map(|mut s| { - s.rename(&name_generator(i)); + s.rename(name_generator(i)); s }) }) .collect::>>() })?; - StructChunked::from_series(ca.name(), &fields) + StructChunked::from_series(ca.name().clone(), &fields) } } diff --git a/crates/polars-ops/src/chunked_array/mode.rs b/crates/polars-ops/src/chunked_array/mode.rs index 26b728306c5e..a36b161775ca 100644 --- a/crates/polars-ops/src/chunked_array/mode.rs +++ b/crates/polars-ops/src/chunked_array/mode.rs @@ -89,31 +89,32 @@ mod test { #[test] fn mode_test() { - let ca = Int32Chunked::from_slice("test", &[0, 1, 2, 3, 4, 4, 5, 6, 5, 0]); + let ca = Int32Chunked::from_slice("test".into(), &[0, 1, 2, 3, 4, 4, 5, 6, 5, 0]); let mut result = mode_primitive(&ca).unwrap().to_vec(); result.sort_by_key(|a| a.unwrap()); assert_eq!(&result, &[Some(0), Some(4), Some(5)]); - let ca = Int32Chunked::from_slice("test", &[1, 1]); + let ca = Int32Chunked::from_slice("test".into(), &[1, 1]); let mut result = mode_primitive(&ca).unwrap().to_vec(); result.sort_by_key(|a| a.unwrap()); assert_eq!(&result, &[Some(1)]); - let ca = Int32Chunked::from_slice("test", &[]); + let ca = Int32Chunked::from_slice("test".into(), &[]); let mut result = mode_primitive(&ca).unwrap().to_vec(); result.sort_by_key(|a| a.unwrap()); assert_eq!(result, &[]); - let ca = Float32Chunked::from_slice("test", &[1.0f32, 2.0, 2.0, 3.0, 3.0, 3.0]); + let ca = 
Float32Chunked::from_slice("test".into(), &[1.0f32, 2.0, 2.0, 3.0, 3.0, 3.0]); let result = mode_primitive(&ca).unwrap().to_vec(); assert_eq!(result, &[Some(3.0f32)]); - let ca = StringChunked::from_slice("test", &["test", "test", "test", "another test"]); + let ca = + StringChunked::from_slice("test".into(), &["test", "test", "test", "another test"]); let result = mode_primitive(&ca).unwrap(); let vec_result4: Vec> = result.into_iter().collect(); assert_eq!(vec_result4, &[Some("test")]); - let mut ca_builder = CategoricalChunkedBuilder::new("test", 5, Default::default()); + let mut ca_builder = CategoricalChunkedBuilder::new("test".into(), 5, Default::default()); ca_builder.append_value("test"); ca_builder.append_value("test"); ca_builder.append_value("test2"); diff --git a/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs b/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs index 6c811ccbbf0f..ec1d8b2c9d4f 100644 --- a/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs +++ b/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs @@ -30,7 +30,7 @@ where .reduce(min_or_max_fn) } -pub fn nan_min_s(s: &Series, name: &str) -> Series { +pub fn nan_min_s(s: &Series, name: PlSmallStr) -> Series { match s.dtype() { DataType::Float32 => { let ca = s.f32().unwrap(); @@ -44,7 +44,7 @@ pub fn nan_min_s(s: &Series, name: &str) -> Series { } } -pub fn nan_max_s(s: &Series, name: &str) -> Series { +pub fn nan_max_s(s: &Series, name: PlSmallStr) -> Series { match s.dtype() { DataType::Float32 => { let ca = s.f32().unwrap(); diff --git a/crates/polars-ops/src/chunked_array/repeat_by.rs b/crates/polars-ops/src/chunked_array/repeat_by.rs index 8ccf9ae58141..0adec3154275 100644 --- a/crates/polars-ops/src/chunked_array/repeat_by.rs +++ b/crates/polars-ops/src/chunked_array/repeat_by.rs @@ -15,7 +15,7 @@ fn check_lengths(length_srs: usize, length_by: usize) -> PolarsResult<()> { fn new_by(by: &IdxCa, len: usize) -> IdxCa { IdxCa::new( - "", + PlSmallStr::EMPTY, std::iter::repeat(by.get(0).unwrap()) .take(len) .collect::>(), diff --git a/crates/polars-ops/src/chunked_array/scatter.rs b/crates/polars-ops/src/chunked_array/scatter.rs index 6e535ea60480..820989c294fe 100644 --- a/crates/polars-ops/src/chunked_array/scatter.rs +++ b/crates/polars-ops/src/chunked_array/scatter.rs @@ -143,7 +143,7 @@ impl<'a> ChunkedSet<&'a str> for &'a StringChunked { check_bounds(idx, self.len() as IdxSize)?; check_sorted(idx)?; let mut ca_iter = self.into_iter().enumerate(); - let mut builder = StringChunkedBuilder::new(self.name(), self.len()); + let mut builder = StringChunkedBuilder::new(self.name().clone(), self.len()); for (current_idx, current_value) in idx.iter().zip(values) { for (cnt_idx, opt_val_self) in &mut ca_iter { @@ -172,7 +172,7 @@ impl ChunkedSet for &BooleanChunked { check_bounds(idx, self.len() as IdxSize)?; check_sorted(idx)?; let mut ca_iter = self.into_iter().enumerate(); - let mut builder = BooleanChunkedBuilder::new(self.name(), self.len()); + let mut builder = BooleanChunkedBuilder::new(self.name().clone(), self.len()); for (current_idx, current_value) in idx.iter().zip(values) { for (cnt_idx, opt_val_self) in &mut ca_iter { diff --git a/crates/polars-ops/src/chunked_array/strings/case.rs b/crates/polars-ops/src/chunked_array/strings/case.rs index 8e8482eb1291..7bb348e28803 100644 --- a/crates/polars-ops/src/chunked_array/strings/case.rs +++ b/crates/polars-ops/src/chunked_array/strings/case.rs @@ -154,7 +154,7 @@ pub(super) fn to_titlecase<'a>(ca: &'a 
StringChunked) -> StringChunked { } else { s.push(c); } - next_is_upper = c.is_whitespace(); + next_is_upper = !c.is_alphanumeric(); } // Put buf back for next iteration. diff --git a/crates/polars-ops/src/chunked_array/strings/concat.rs b/crates/polars-ops/src/chunked_array/strings/concat.rs index 67d1f244843d..490cbc2c85c6 100644 --- a/crates/polars-ops/src/chunked_array/strings/concat.rs +++ b/crates/polars-ops/src/chunked_array/strings/concat.rs @@ -6,17 +6,17 @@ use polars_core::prelude::*; // Vertically concatenate all strings in a StringChunked. pub fn str_join(ca: &StringChunked, delimiter: &str, ignore_nulls: bool) -> StringChunked { if ca.is_empty() { - return StringChunked::new(ca.name(), &[""]); + return StringChunked::new(ca.name().clone(), &[""]); } // Propagate null value. if !ignore_nulls && ca.null_count() != 0 { - return StringChunked::full_null(ca.name(), 1); + return StringChunked::full_null(ca.name().clone(), 1); } // Fast path for all nulls. if ignore_nulls && ca.null_count() == ca.len() { - return StringChunked::new(ca.name(), &[""]); + return StringChunked::new(ca.name().clone(), &[""]); } if ca.len() == 1 { @@ -44,7 +44,7 @@ pub fn str_join(ca: &StringChunked, delimiter: &str, ignore_nulls: bool) -> Stri let arr = unsafe { Utf8Array::from_data_unchecked_default(offsets.into(), buf.into(), None) }; // conversion is cheap with one value. let arr = utf8_to_utf8view(&arr); - StringChunked::with_chunk(ca.name(), arr) + StringChunked::with_chunk(ca.name().clone(), arr) } enum ColumnIter { @@ -61,7 +61,7 @@ pub fn hor_str_concat( ignore_nulls: bool, ) -> PolarsResult { if cas.is_empty() { - return Ok(StringChunked::full_null("", 0)); + return Ok(StringChunked::full_null(PlSmallStr::EMPTY, 0)); } if cas.len() == 1 { let ca = cas[0]; @@ -84,7 +84,7 @@ pub fn hor_str_concat( ComputeError: "all series in `hor_str_concat` should have equal or unit length" ); - let mut builder = StringChunkedBuilder::new(cas[0].name(), len); + let mut builder = StringChunkedBuilder::new(cas[0].name().clone(), len); // Broadcast if appropriate. 
let mut cols: Vec<_> = cas @@ -141,7 +141,7 @@ mod test { #[test] fn test_str_concat() { - let ca = Int32Chunked::new("foo", &[Some(1), None, Some(3)]); + let ca = Int32Chunked::new("foo".into(), &[Some(1), None, Some(3)]); let ca_str = ca.cast(&DataType::String).unwrap(); let out = str_join(ca_str.str().unwrap(), "-", true); @@ -151,13 +151,13 @@ mod test { #[test] fn test_hor_str_concat() { - let a = StringChunked::new("a", &["foo", "bar"]); - let b = StringChunked::new("b", &["spam", "ham"]); + let a = StringChunked::new("a".into(), &["foo", "bar"]); + let b = StringChunked::new("b".into(), &["spam", "ham"]); let out = hor_str_concat(&[&a, &b], "_", true).unwrap(); assert_eq!(Vec::from(&out), &[Some("foo_spam"), Some("bar_ham")]); - let c = StringChunked::new("b", &["literal"]); + let c = StringChunked::new("b".into(), &["literal"]); let out = hor_str_concat(&[&a, &b, &c], "_", true).unwrap(); assert_eq!( Vec::from(&out), diff --git a/crates/polars-ops/src/chunked_array/strings/extract.rs b/crates/polars-ops/src/chunked_array/strings/extract.rs index 9663b8d04aae..35f38e40d61d 100644 --- a/crates/polars-ops/src/chunked_array/strings/extract.rs +++ b/crates/polars-ops/src/chunked_array/strings/extract.rs @@ -13,7 +13,7 @@ fn extract_groups_array( arr: &Utf8ViewArray, reg: &Regex, names: &[&str], - data_type: ArrowDataType, + dtype: ArrowDataType, ) -> PolarsResult { let mut builders = (0..names.len()) .map(|_| MutablePlString::with_capacity(arr.len())) @@ -36,7 +36,7 @@ fn extract_groups_array( } let values = builders.into_iter().map(|a| a.freeze().boxed()).collect(); - Ok(StructArray::new(data_type.clone(), values, arr.validity().cloned()).boxed()) + Ok(StructArray::new(dtype.clone(), values, arr.validity().cloned()).boxed()) } #[cfg(feature = "extract_groups")] @@ -48,11 +48,14 @@ pub(super) fn extract_groups( let reg = Regex::new(pat)?; let n_fields = reg.captures_len(); if n_fields == 1 { - return StructChunked::from_series(ca.name(), &[Series::new_null(ca.name(), ca.len())]) - .map(|ca| ca.into_series()); + return StructChunked::from_series( + ca.name().clone(), + &[Series::new_null(ca.name().clone(), ca.len())], + ) + .map(|ca| ca.into_series()); } - let data_type = dtype.try_to_arrow(CompatLevel::newest())?; + let arrow_dtype = dtype.try_to_arrow(CompatLevel::newest())?; let DataType::Struct(fields) = dtype else { unreachable!() // Implementation error if it isn't a struct. 
}; @@ -63,10 +66,10 @@ pub(super) fn extract_groups( let chunks = ca .downcast_iter() - .map(|array| extract_groups_array(array, ®, &names, data_type.clone())) + .map(|array| extract_groups_array(array, ®, &names, arrow_dtype.clone())) .collect::>>()?; - Series::try_from((ca.name(), chunks)) + Series::try_from((ca.name().clone(), chunks)) } fn extract_group_reg_lit( @@ -153,21 +156,21 @@ pub(super) fn extract_group( let reg = Regex::new(pat)?; try_unary_mut_with_options(ca, |arr| extract_group_reg_lit(arr, ®, group_index)) } else { - Ok(StringChunked::full_null(ca.name(), ca.len())) + Ok(StringChunked::full_null(ca.name().clone(), ca.len())) } }, (1, _) => { if let Some(s) = ca.get(0) { try_unary_mut_with_options(pat, |pat| extract_group_array_lit(s, pat, group_index)) } else { - Ok(StringChunked::full_null(ca.name(), pat.len())) + Ok(StringChunked::full_null(ca.name().clone(), pat.len())) } }, (len_ca, len_pat) if len_ca == len_pat => try_binary_mut_with_options( ca, pat, |ca, pat| extract_group_binary(ca, pat, group_index), - ca.name(), + ca.name().clone(), ), _ => { polars_bail!(ComputeError: "ca(len: {}) and pat(len: {}) should either broadcast or have the same length", ca.len(), pat.len()) diff --git a/crates/polars-ops/src/chunked_array/strings/find_many.rs b/crates/polars-ops/src/chunked_array/strings/find_many.rs index 9bf0510e93d9..d56d8b3e014d 100644 --- a/crates/polars-ops/src/chunked_array/strings/find_many.rs +++ b/crates/polars-ops/src/chunked_array/strings/find_many.rs @@ -80,7 +80,8 @@ pub fn extract_many( ) -> PolarsResult { match patterns.dtype() { DataType::List(inner) if inner.is_string() => { - let mut builder = ListStringChunkedBuilder::new(ca.name(), ca.len(), ca.len() * 2); + let mut builder = + ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.len() * 2); let patterns = patterns.list().unwrap(); let (ca, patterns) = align_chunks_binary(ca, patterns); @@ -101,7 +102,8 @@ pub fn extract_many( DataType::String => { let patterns = patterns.str().unwrap(); let ac = build_ac(patterns, ascii_case_insensitive)?; - let mut builder = ListStringChunkedBuilder::new(ca.name(), ca.len(), ca.len() * 2); + let mut builder = + ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.len() * 2); for arr in ca.downcast_iter() { for opt_val in arr.into_iter() { diff --git a/crates/polars-ops/src/chunked_array/strings/json_path.rs b/crates/polars-ops/src/chunked_array/strings/json_path.rs index ba2124e4a0be..a25ce1937332 100644 --- a/crates/polars-ops/src/chunked_array/strings/json_path.rs +++ b/crates/polars-ops/src/chunked_array/strings/json_path.rs @@ -54,7 +54,7 @@ pub trait Utf8JsonPathImpl: AsString { )?; unary_elementwise(ca, |opt_s| opt_s.and_then(|s| extract_json(&pat, s))) } else { - StringChunked::full_null(ca.name(), ca.len()) + StringChunked::full_null(ca.name().clone(), ca.len()) }; Ok(out) }, @@ -112,7 +112,7 @@ pub trait Utf8JsonPathImpl: AsString { ca.len(), ) .map_err(|e| polars_err!(ComputeError: "error deserializing JSON: {}", e))?; - Series::try_from(("", array)) + Series::try_from((PlSmallStr::EMPTY, array)) } fn json_path_select(&self, json_path: &str) -> PolarsResult { @@ -167,7 +167,7 @@ mod tests { #[test] fn test_json_infer() { let s = Series::new( - "json", + "json".into(), [ None, Some(r#"{"a": 1, "b": [{"c": 0}, {"c": 1}]}"#), @@ -177,10 +177,10 @@ mod tests { ); let ca = s.str().unwrap(); - let inner_dtype = DataType::Struct(vec![Field::new("c", DataType::Int64)]); + let inner_dtype = DataType::Struct(vec![Field::new("c".into(), 
DataType::Int64)]); let expected_dtype = DataType::Struct(vec![ - Field::new("a", DataType::Int64), - Field::new("b", DataType::List(Box::new(inner_dtype))), + Field::new("a".into(), DataType::Int64), + Field::new("b".into(), DataType::List(Box::new(inner_dtype))), ]); assert_eq!(ca.json_infer(None).unwrap(), expected_dtype); @@ -192,7 +192,7 @@ mod tests { #[test] fn test_json_decode() { let s = Series::new( - "json", + "json".into(), [ None, Some(r#"{"a": 1, "b": "hello"}"#), @@ -203,14 +203,14 @@ mod tests { let ca = s.str().unwrap(); let expected_series = StructChunked::from_series( - "", + "".into(), &[ - Series::new("a", &[None, Some(1), Some(2), None]), - Series::new("b", &[None, Some("hello"), Some("goodbye"), None]), + Series::new("a".into(), &[None, Some(1), Some(2), None]), + Series::new("b".into(), &[None, Some("hello"), Some("goodbye"), None]), ], ) .unwrap() - .with_outer_validity_chunked(BooleanChunked::new("", [false, true, true, false])) + .with_outer_validity_chunked(BooleanChunked::new("".into(), [false, true, true, false])) .into_series(); let expected_dtype = expected_series.dtype().clone(); @@ -227,7 +227,7 @@ mod tests { #[test] fn test_json_path_select() { let s = Series::new( - "json", + "json".into(), [ None, Some(r#"{"a":1,"b":[{"c":0},{"c":1}]}"#), @@ -244,7 +244,7 @@ mod tests { .equals_missing(&s)); let b_series = Series::new( - "json", + "json".into(), [ None, Some(r#"[{"c":0},{"c":1}]"#), @@ -258,7 +258,10 @@ mod tests { .into_series() .equals_missing(&b_series)); - let c_series = Series::new("json", [None, Some(r#"[0,1]"#), Some(r#"[2,5]"#), None]); + let c_series = Series::new( + "json".into(), + [None, Some(r#"[0,1]"#), Some(r#"[2,5]"#), None], + ); assert!(ca .json_path_select("$.b[:].c") .unwrap() @@ -269,7 +272,7 @@ mod tests { #[test] fn test_json_path_extract() { let s = Series::new( - "json", + "json".into(), [ None, Some(r#"{"a":1,"b":[{"c":0},{"c":1}]}"#), @@ -280,11 +283,11 @@ mod tests { let ca = s.str().unwrap(); let c_series = Series::new( - "", + "".into(), [ None, - Some(Series::new("", &[0, 1])), - Some(Series::new("", &[2, 5])), + Some(Series::new("".into(), &[0, 1])), + Some(Series::new("".into(), &[2, 5])), None, ], ); diff --git a/crates/polars-ops/src/chunked_array/strings/namespace.rs b/crates/polars-ops/src/chunked_array/strings/namespace.rs index b9c1e3041967..1f2899764e4f 100644 --- a/crates/polars-ops/src/chunked_array/strings/namespace.rs +++ b/crates/polars-ops/src/chunked_array/strings/namespace.rs @@ -133,10 +133,10 @@ pub trait StringNameSpaceImpl: AsString { ca.contains(pat, strict) } }, - None => Ok(BooleanChunked::full_null(ca.name(), ca.len())), + None => Ok(BooleanChunked::full_null(ca.name().clone(), ca.len())), }, (1, _) if ca.null_count() == 1 => Ok(BooleanChunked::full_null( - ca.name(), + ca.name().clone(), ca.len().max(pat.len()), )), _ => { @@ -188,10 +188,13 @@ pub trait StringNameSpaceImpl: AsString { ca.find(pat, strict) } } else { - Ok(UInt32Chunked::full_null(ca.name(), ca.len())) + Ok(UInt32Chunked::full_null(ca.name().clone(), ca.len())) }; } else if ca.len() == 1 && ca.null_count() == 1 { - return Ok(UInt32Chunked::full_null(ca.name(), ca.len().max(pat.len()))); + return Ok(UInt32Chunked::full_null( + ca.name().clone(), + ca.len().max(pat.len()), + )); } if literal { Ok(broadcast_binary_elementwise( @@ -267,7 +270,7 @@ pub trait StringNameSpaceImpl: AsString { let out: BooleanChunked = if let Some(reg) = opt_reg { unary_elementwise_values(ca, |s| reg.is_match(s)) } else { - 
BooleanChunked::full_null(ca.name(), ca.len()) + BooleanChunked::full_null(ca.name().clone(), ca.len()) }; Ok(out) } @@ -292,7 +295,7 @@ pub trait StringNameSpaceImpl: AsString { Ok(rx) => Ok(unary_elementwise(ca, |opt_s| { opt_s.and_then(|s| rx.find(s)).map(|m| m.start() as u32) })), - Err(_) if !strict => Ok(UInt32Chunked::full_null(ca.name(), ca.len())), + Err(_) if !strict => Ok(UInt32Chunked::full_null(ca.name().clone(), ca.len())), Err(e) => Err(PolarsError::ComputeError( format!("Invalid regular expression: {}", e).into(), )), @@ -402,7 +405,8 @@ pub trait StringNameSpaceImpl: AsString { let ca = self.as_string(); let reg = Regex::new(pat)?; - let mut builder = ListStringChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size()); + let mut builder = + ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.get_values_size()); for arr in ca.downcast_iter() { for opt_s in arr { match opt_s { @@ -495,7 +499,8 @@ pub trait StringNameSpaceImpl: AsString { // A sqrt(n) regex cache is not too small, not too large. let mut reg_cache = FastFixedCache::new((ca.len() as f64).sqrt() as usize); - let mut builder = ListStringChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size()); + let mut builder = + ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.get_values_size()); binary_elementwise_for_each(ca, pat, |opt_s, opt_pat| match (opt_s, opt_pat) { (_, None) | (None, _) => builder.append_null(), (Some(s), Some(pat)) => { @@ -560,7 +565,7 @@ pub trait StringNameSpaceImpl: AsString { let out: UInt32Chunked = broadcast_try_binary_elementwise(ca, pat, op)?; - Ok(out.with_name(ca.name())) + Ok(out.with_name(ca.name().clone())) } /// Modify the strings to their lowercase equivalent. diff --git a/crates/polars-ops/src/chunked_array/strings/split.rs b/crates/polars-ops/src/chunked_array/strings/split.rs index 1902f6acf10b..d86e0efac2ae 100644 --- a/crates/polars-ops/src/chunked_array/strings/split.rs +++ b/crates/polars-ops/src/chunked_array/strings/split.rs @@ -65,6 +65,8 @@ where F: Fn(&'a str, &'a str) -> I, I: Iterator, { + use polars_utils::format_pl_smallstr; + let mut arrs = (0..n) .map(|_| MutableUtf8Array::::with_capacity(ca.len())) .collect::>(); @@ -143,11 +145,11 @@ where .into_iter() .enumerate() .map(|(i, mut arr)| { - Series::try_from((format!("field_{i}").as_str(), arr.as_box())).unwrap() + Series::try_from((format_pl_smallstr!("field_{i}"), arr.as_box())).unwrap() }) .collect::>(); - StructChunked::from_series(ca.name(), &fields) + StructChunked::from_series(ca.name().clone(), &fields) } pub fn split_helper<'a, F, I>(ca: &'a StringChunked, by: &'a StringChunked, op: F) -> ListChunked @@ -158,7 +160,7 @@ where if by.len() == 1 { if let Some(by) = by.get(0) { let mut builder = - ListStringChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size()); + ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.get_values_size()); if by.is_empty() { ca.for_each(|opt_s| match opt_s { @@ -173,10 +175,11 @@ where } builder.finish() } else { - ListChunked::full_null_with_dtype(ca.name(), ca.len(), &DataType::String) + ListChunked::full_null_with_dtype(ca.name().clone(), ca.len(), &DataType::String) } } else { - let mut builder = ListStringChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size()); + let mut builder = + ListStringChunkedBuilder::new(ca.name().clone(), ca.len(), ca.get_values_size()); binary_elementwise_for_each(ca, by, |opt_s, opt_by| match (opt_s, opt_by) { (Some(s), Some(by)) => { diff --git 
a/crates/polars-ops/src/chunked_array/strings/strip.rs b/crates/polars-ops/src/chunked_array/strings/strip.rs index c7468d238807..cd92704d6bfe 100644 --- a/crates/polars-ops/src/chunked_array/strings/strip.rs +++ b/crates/polars-ops/src/chunked_array/strings/strip.rs @@ -124,7 +124,7 @@ pub fn strip_prefix(ca: &StringChunked, prefix: &StringChunked) -> StringChunked Some(prefix) => unary_elementwise(ca, |opt_s| { opt_s.map(|s| s.strip_prefix(prefix).unwrap_or(s)) }), - _ => StringChunked::full_null(ca.name(), ca.len()), + _ => StringChunked::full_null(ca.name().clone(), ca.len()), }, _ => broadcast_binary_elementwise(ca, prefix, strip_prefix_binary), } @@ -136,7 +136,7 @@ pub fn strip_suffix(ca: &StringChunked, suffix: &StringChunked) -> StringChunked Some(suffix) => unary_elementwise(ca, |opt_s| { opt_s.map(|s| s.strip_suffix(suffix).unwrap_or(s)) }), - _ => StringChunked::full_null(ca.name(), ca.len()), + _ => StringChunked::full_null(ca.name().clone(), ca.len()), }, _ => broadcast_binary_elementwise(ca, suffix, strip_suffix_binary), } diff --git a/crates/polars-ops/src/chunked_array/strings/substring.rs b/crates/polars-ops/src/chunked_array/strings/substring.rs index c9512f11bb2c..41fed212d439 100644 --- a/crates/polars-ops/src/chunked_array/strings/substring.rs +++ b/crates/polars-ops/src/chunked_array/strings/substring.rs @@ -163,14 +163,14 @@ pub(super) fn substring( let str_val = ca.get(0); let offset = offset.get(0); unary_elementwise(length, |length| substring_ternary(str_val, offset, length)) - .with_name(ca.name()) + .with_name(ca.name().clone()) }, (_, 1, 1) => { let offset = offset.get(0); let length = length.get(0).unwrap_or(u64::MAX); let Some(offset) = offset else { - return StringChunked::full_null(ca.name(), ca.len()); + return StringChunked::full_null(ca.name().clone(), ca.len()); }; unsafe { @@ -184,7 +184,7 @@ pub(super) fn substring( let str_val = ca.get(0); let length = length.get(0); unary_elementwise(offset, |offset| substring_ternary(str_val, offset, length)) - .with_name(ca.name()) + .with_name(ca.name().clone()) }, (1, len_b, len_c) if len_b == len_c => { let str_val = ca.get(0); @@ -225,7 +225,7 @@ pub(super) fn head(ca: &StringChunked, n: &Int64Chunked) -> PolarsResult { let n = n.get(0); let Some(n) = n else { - return Ok(StringChunked::full_null(ca.name(), len)); + return Ok(StringChunked::full_null(ca.name().clone(), len)); }; Ok(unsafe { @@ -238,7 +238,7 @@ pub(super) fn head(ca: &StringChunked, n: &Int64Chunked) -> PolarsResult { let str_val = ca.get(0); - Ok(unary_elementwise(n, |n| head_binary(str_val, n)).with_name(ca.name())) + Ok(unary_elementwise(n, |n| head_binary(str_val, n)).with_name(ca.name().clone())) }, (a, b) => { polars_ensure!(a == b, ShapeMismatch: "lengths of arguments do not align in 'str.head' got length: {} for column: {}, got length: {} for argument 'n'", a, ca.name(), b); @@ -252,7 +252,7 @@ pub(super) fn tail(ca: &StringChunked, n: &Int64Chunked) -> PolarsResult { let n = n.get(0); let Some(n) = n else { - return Ok(StringChunked::full_null(ca.name(), len)); + return Ok(StringChunked::full_null(ca.name().clone(), len)); }; unsafe { ca.apply_views(|view, val| { @@ -264,7 +264,7 @@ pub(super) fn tail(ca: &StringChunked, n: &Int64Chunked) -> PolarsResult { let str_val = ca.get(0); - unary_elementwise(n, |n| tail_binary(str_val, n)).with_name(ca.name()) + unary_elementwise(n, |n| tail_binary(str_val, n)).with_name(ca.name().clone()) }, (a, b) => { polars_ensure!(a == b, ShapeMismatch: "lengths of arguments do not align in 'str.tail' 
got length: {} for column: {}, got length: {} for argument 'n'", a, ca.name(), b); diff --git a/crates/polars-ops/src/chunked_array/top_k.rs b/crates/polars-ops/src/chunked_array/top_k.rs index f5948d0c88a4..9772a5593be0 100644 --- a/crates/polars-ops/src/chunked_array/top_k.rs +++ b/crates/polars-ops/src/chunked_array/top_k.rs @@ -204,6 +204,7 @@ pub fn top_k(s: &[Series], descending: bool) -> PolarsResult { Ok(ca.into_series()) }, DataType::Binary => Ok(top_k_binary_impl(s.binary().unwrap(), k, descending).into_series()), + #[cfg(feature = "dtype-decimal")] DataType::Decimal(_, _) => { let src = src.decimal().unwrap(); let ca = top_k_num_impl(src, k, descending); @@ -212,6 +213,7 @@ pub fn top_k(s: &[Series], descending: bool) -> PolarsResult { Ok(lca.into_series()) }, DataType::Null => Ok(src.slice(0, k)), + #[cfg(feature = "dtype-struct")] DataType::Struct(_) => { // Fallback to more generic impl. top_k_by_impl(k, src, &[src.clone()], vec![descending]) diff --git a/crates/polars-ops/src/frame/join/args.rs b/crates/polars-ops/src/frame/join/args.rs index a5f0b0197e9f..2f5d6504eba7 100644 --- a/crates/polars-ops/src/frame/join/args.rs +++ b/crates/polars-ops/src/frame/join/args.rs @@ -15,6 +15,7 @@ pub type ChunkJoinOptIds = Vec; #[cfg(not(feature = "chunked_ids"))] pub type ChunkJoinIds = Vec; +use polars_core::export::once_cell::sync::Lazy; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -23,7 +24,7 @@ use serde::{Deserialize, Serialize}; pub struct JoinArgs { pub how: JoinType, pub validation: JoinValidation, - pub suffix: Option, + pub suffix: Option, pub slice: Option<(i64, usize)>, pub join_nulls: bool, pub coalesce: JoinCoalesce, @@ -57,6 +58,7 @@ impl JoinCoalesce { }, #[cfg(feature = "asof_join")] AsOf(_) => matches!(self, JoinSpecific | CoalesceColumns), + IEJoin(_) => false, Cross => false, #[cfg(feature = "semi_anti_join")] Semi | Anti => false, @@ -94,13 +96,14 @@ impl JoinArgs { self } - pub fn with_suffix(mut self, suffix: Option) -> Self { + pub fn with_suffix(mut self, suffix: Option) -> Self { self.suffix = suffix; self } - pub fn suffix(&self) -> &str { - self.suffix.as_deref().unwrap_or("_right") + pub fn suffix(&self) -> &PlSmallStr { + static DEFAULT: Lazy = Lazy::new(|| PlSmallStr::from_static("_right")); + self.suffix.as_ref().unwrap_or(&*DEFAULT) } } @@ -118,6 +121,7 @@ pub enum JoinType { Semi, #[cfg(feature = "semi_anti_join")] Anti, + IEJoin(IEJoinOptions), } impl From for JoinArgs { @@ -136,6 +140,7 @@ impl Display for JoinType { Full { .. 
} => "FULL", #[cfg(feature = "asof_join")] AsOf(_) => "ASOF", + IEJoin(_) => "IEJOIN", Cross => "CROSS", #[cfg(feature = "semi_anti_join")] Semi => "SEMI", diff --git a/crates/polars-ops/src/frame/join/asof/default.rs b/crates/polars-ops/src/frame/join/asof/default.rs index c8c8c68094bf..609d4048bf78 100644 --- a/crates/polars-ops/src/frame/join/asof/default.rs +++ b/crates/polars-ops/src/frame/join/asof/default.rs @@ -15,7 +15,7 @@ where F: FnMut(T::Physical<'a>, T::Physical<'a>) -> bool, { if left.len() == left.null_count() || right.len() == right.null_count() { - return IdxCa::full_null("", left.len()); + return IdxCa::full_null(PlSmallStr::EMPTY, left.len()); } let mut out = vec![0; left.len()]; @@ -55,7 +55,7 @@ where } let bitmap = Bitmap::try_new(mask, out.len()).unwrap(); - IdxCa::from_vec_validity("", out, Some(bitmap)) + IdxCa::from_vec_validity(PlSmallStr::EMPTY, out, Some(bitmap)) } fn join_asof_forward<'a, T, F>(left: &'a T::Array, right: &'a T::Array, filter: F) -> IdxCa diff --git a/crates/polars-ops/src/frame/join/asof/groups.rs b/crates/polars-ops/src/frame/join/asof/groups.rs index 4547fbe4f141..81b05a4b752d 100644 --- a/crates/polars-ops/src/frame/join/asof/groups.rs +++ b/crates/polars-ops/src/frame/join/asof/groups.rs @@ -1,6 +1,5 @@ use std::hash::Hash; -use ahash::RandomState; use hashbrown::HashMap; use num_traits::Zero; use polars_core::hashing::{ @@ -13,13 +12,14 @@ use polars_core::utils::flatten::flatten_nullable; use polars_core::utils::{_set_partition_size, split_and_flatten}; use polars_core::{with_match_physical_float_polars_type, IdBuildHasher, POOL}; use polars_utils::abs_diff::AbsDiff; +use polars_utils::aliases::PlRandomState; use polars_utils::hashing::{hash_to_partition, DirtyHash}; use polars_utils::idx_vec::IdxVec; use polars_utils::nulls::IsNull; +use polars_utils::pl_str::PlSmallStr; use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; use polars_utils::unitvec; use rayon::prelude::*; -use smartstring::alias::String as SmartString; use super::*; @@ -259,7 +259,7 @@ where let split_by_right = split_and_flatten(by_right, n_threads); let offsets = compute_len_offsets(split_by_left.iter().map(|s| s.len())); - let hb = RandomState::default(); + let hb = PlRandomState::default(); let prep_by_left = prepare_bytes(&split_by_left, &hb); let prep_by_right = prepare_bytes(&split_by_right, &hb); let hash_tbls = build_tables(prep_by_right, false); @@ -600,11 +600,11 @@ pub trait AsofJoinBy: IntoDf { other: &DataFrame, left_on: &Series, right_on: &Series, - left_by: Vec, - right_by: Vec, + left_by: Vec, + right_by: Vec, strategy: AsofStrategy, tolerance: Option>, - suffix: Option<&str>, + suffix: Option, slice: Option<(i64, usize)>, coalesce: bool, ) -> PolarsResult { @@ -678,8 +678,9 @@ pub trait AsofJoinBy: IntoDf { let left = self_df.clone(); // SAFETY: join tuples are in bounds. 
- let right_df = - unsafe { proj_other_df.take_unchecked(&IdxCa::with_chunk("", right_join_tuples)) }; + let right_df = unsafe { + proj_other_df.take_unchecked(&IdxCa::with_chunk(PlSmallStr::EMPTY, right_join_tuples)) + }; _finish_join(left, right_df, suffix) } diff --git a/crates/polars-ops/src/frame/join/asof/mod.rs b/crates/polars-ops/src/frame/join/asof/mod.rs index 07fdd69c7399..71e813cdac39 100644 --- a/crates/polars-ops/src/frame/join/asof/mod.rs +++ b/crates/polars-ops/src/frame/join/asof/mod.rs @@ -5,9 +5,9 @@ use std::borrow::Cow; use default::*; pub use groups::AsofJoinBy; use polars_core::prelude::*; +use polars_utils::pl_str::PlSmallStr; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -use smartstring::alias::String as SmartString; #[cfg(feature = "dtype-categorical")] use super::_check_categorical_src; @@ -152,9 +152,9 @@ pub struct AsOfOptions { /// - "5m" /// - "2h15m" /// - "1d6h" - pub tolerance_str: Option, - pub left_by: Option>, - pub right_by: Option>, + pub tolerance_str: Option, + pub left_by: Option>, + pub right_by: Option>, } fn check_asof_columns( @@ -212,7 +212,7 @@ pub trait AsofJoin: IntoDf { right_key: &Series, strategy: AsofStrategy, tolerance: Option>, - suffix: Option, + suffix: Option, slice: Option<(i64, usize)>, coalesce: bool, ) -> PolarsResult { @@ -284,7 +284,7 @@ pub trait AsofJoin: IntoDf { // SAFETY: join tuples are in bounds. let right_df = unsafe { other.take_unchecked(&take_idx) }; - _finish_join(left, right_df, suffix.as_deref()) + _finish_join(left, right_df, suffix) } } diff --git a/crates/polars-ops/src/frame/join/cross_join.rs b/crates/polars-ops/src/frame/join/cross_join.rs index 1e1b1bcba497..c4290e262627 100644 --- a/crates/polars-ops/src/frame/join/cross_join.rs +++ b/crates/polars-ops/src/frame/join/cross_join.rs @@ -1,5 +1,5 @@ use polars_core::utils::{concat_df_unchecked, CustomIterTools, NoNull}; -use smartstring::alias::String as SmartString; +use polars_utils::pl_str::PlSmallStr; use super::*; @@ -99,7 +99,7 @@ pub trait CrossJoin: IntoDf { fn _cross_join_with_names( &self, other: &DataFrame, - names: &[SmartString], + names: &[PlSmallStr], ) -> PolarsResult { let (mut l_df, r_df) = self.cross_join_dfs(other, None, false)?; @@ -111,7 +111,7 @@ pub trait CrossJoin: IntoDf { .zip(names) .for_each(|(s, name)| { if s.name() != name { - s.rename(name); + s.rename(name.clone()); } }); } @@ -122,7 +122,7 @@ pub trait CrossJoin: IntoDf { fn cross_join( &self, other: &DataFrame, - suffix: Option<&str>, + suffix: Option, slice: Option<(i64, usize)>, ) -> PolarsResult { let (l_df, r_df) = self.cross_join_dfs(other, slice, true)?; diff --git a/crates/polars-ops/src/frame/join/dispatch_left_right.rs b/crates/polars-ops/src/frame/join/dispatch_left_right.rs index d8dd5396b1e2..f5c91de88a74 100644 --- a/crates/polars-ops/src/frame/join/dispatch_left_right.rs +++ b/crates/polars-ops/src/frame/join/dispatch_left_right.rs @@ -8,12 +8,12 @@ pub(super) fn left_join_from_series( s_right: &Series, args: JoinArgs, verbose: bool, - drop_names: Option<&[&str]>, + drop_names: Option>, ) -> PolarsResult { let (df_left, df_right) = materialize_left_join_from_series( left, right, s_left, s_right, &args, verbose, drop_names, )?; - _finish_join(df_left, df_right, args.suffix.as_deref()) + _finish_join(df_left, df_right, args.suffix) } pub(super) fn right_join_from_series( @@ -23,13 +23,13 @@ pub(super) fn right_join_from_series( s_right: &Series, args: JoinArgs, verbose: bool, - drop_names: Option<&[&str]>, + drop_names: Option>, ) -> 
PolarsResult { // Swap the order of tables to do a right join. let (df_right, df_left) = materialize_left_join_from_series( right, left, s_right, s_left, &args, verbose, drop_names, )?; - _finish_join(df_left, df_right, args.suffix.as_deref()) + _finish_join(df_left, df_right, args.suffix) } pub fn materialize_left_join_from_series( @@ -39,7 +39,7 @@ pub fn materialize_left_join_from_series( s_right: &Series, args: &JoinArgs, verbose: bool, - drop_names: Option<&[&str]>, + drop_names: Option>, ) -> PolarsResult<(DataFrame, DataFrame)> { #[cfg(feature = "dtype-categorical")] _check_categorical_src(s_left.dtype(), s_right.dtype())?; diff --git a/crates/polars-ops/src/frame/join/general.rs b/crates/polars-ops/src/frame/join/general.rs index 2e4d38e2af0d..5840b853425c 100644 --- a/crates/polars-ops/src/frame/join/general.rs +++ b/crates/polars-ops/src/frame/join/general.rs @@ -1,12 +1,14 @@ +use polars_utils::format_pl_smallstr; + use super::*; use crate::series::coalesce_series; -pub fn _join_suffix_name(name: &str, suffix: &str) -> String { - format!("{name}{suffix}") +pub fn _join_suffix_name(name: &str, suffix: &str) -> PlSmallStr { + format_pl_smallstr!("{name}{suffix}") } -fn get_suffix(suffix: Option<&str>) -> &str { - suffix.unwrap_or("_right") +fn get_suffix(suffix: Option) -> PlSmallStr { + suffix.unwrap_or_else(|| PlSmallStr::from_static("_right")) } /// Utility method to finish a join. @@ -14,7 +16,7 @@ fn get_suffix(suffix: Option<&str>) -> &str { pub fn _finish_join( mut df_left: DataFrame, mut df_right: DataFrame, - suffix: Option<&str>, + suffix: Option, ) -> PolarsResult { let mut left_names = PlHashSet::with_capacity(df_left.width()); @@ -32,8 +34,8 @@ pub fn _finish_join( let suffix = get_suffix(suffix); for name in rename_strs { - let new_name = _join_suffix_name(&name, suffix); - df_right.rename(&name, new_name.as_str()).map_err(|_| { + let new_name = _join_suffix_name(name.as_str(), suffix.as_str()); + df_right.rename(&name, new_name.clone()).map_err(|_| { polars_err!(Duplicate: "column with name '{}' already exists\n\n\ You may want to try:\n\ - renaming the column prior to joining\n\ @@ -48,9 +50,9 @@ pub fn _finish_join( pub fn _coalesce_full_join( mut df: DataFrame, - keys_left: &[&str], - keys_right: &[&str], - suffix: Option<&str>, + keys_left: &[PlSmallStr], + keys_right: &[PlSmallStr], + suffix: Option, df_left: &DataFrame, ) -> DataFrame { // No need to allocate the schema because we already @@ -67,14 +69,14 @@ pub fn _coalesce_full_join( // SAFETY: we maintain invariants. 
let columns = unsafe { df.get_columns_mut() }; - for (&l, &r) in keys_left.iter().zip(keys_right.iter()) { - let pos_l = schema.get_full(l).unwrap().0; + let suffix = get_suffix(suffix); + for (l, r) in keys_left.iter().zip(keys_right.iter()) { + let pos_l = schema.get_full(l.as_str()).unwrap().0; - let r = if l == r || schema_left.contains(r) { - let suffix = get_suffix(suffix); - Cow::Owned(_join_suffix_name(r, suffix)) + let r = if l == r || schema_left.contains(r.as_str()) { + _join_suffix_name(r.as_str(), suffix.as_str()) } else { - Cow::Borrowed(r) + r.clone() }; let pos_r = schema.get_full(&r).unwrap().0; diff --git a/crates/polars-ops/src/frame/join/hash_join/mod.rs b/crates/polars-ops/src/frame/join/hash_join/mod.rs index 65e6d0a56dce..35e4ea9403af 100644 --- a/crates/polars-ops/src/frame/join/hash_join/mod.rs +++ b/crates/polars-ops/src/frame/join/hash_join/mod.rs @@ -147,8 +147,8 @@ pub trait JoinDispatch: IntoDf { join_idx_l.slice(offset, len); join_idx_r.slice(offset, len); } - let idx_ca_l = IdxCa::with_chunk("", join_idx_l); - let idx_ca_r = IdxCa::with_chunk("", join_idx_r); + let idx_ca_l = IdxCa::with_chunk(PlSmallStr::EMPTY, join_idx_l); + let idx_ca_r = IdxCa::with_chunk(PlSmallStr::EMPTY, join_idx_r); // Take the left and right dataframes by join tuples let (df_left, df_right) = POOL.join( @@ -157,13 +157,13 @@ pub trait JoinDispatch: IntoDf { ); let coalesce = args.coalesce.coalesce(&JoinType::Full); - let out = _finish_join(df_left, df_right, args.suffix.as_deref()); + let out = _finish_join(df_left, df_right, args.suffix.clone()); if coalesce { Ok(_coalesce_full_join( out?, - &[s_left.name()], - &[s_right.name()], - args.suffix.as_deref(), + &[s_left.name().clone()], + &[s_right.name().clone()], + args.suffix.clone(), df_self, )) } else { diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs index 8a09002ba54d..a8093873ea51 100644 --- a/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs @@ -2,6 +2,7 @@ use arrow::array::PrimitiveArray; use polars_core::series::BitRepr; use polars_core::utils::split; use polars_core::with_match_physical_float_polars_type; +use polars_utils::aliases::PlRandomState; use polars_utils::hashing::DirtyHash; use polars_utils::nulls::IsNull; use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; @@ -509,7 +510,7 @@ where #[cfg(feature = "asof_join")] pub fn prepare_bytes<'a>( been_split: &'a [BinaryChunked], - hb: &RandomState, + hb: &PlRandomState, ) -> Vec>> { POOL.install(|| { been_split @@ -536,7 +537,7 @@ fn prepare_binary<'a, T>( Vec>>, Vec>>, bool, - RandomState, + PlRandomState, ) where T: PolarsDataType, @@ -547,7 +548,7 @@ where } else { (ca, other, false) }; - let hb = RandomState::default(); + let hb = PlRandomState::default(); let bh_a = a.to_bytes_hashes(true, hb.clone()); let bh_b = b.to_bytes_hashes(true, hb.clone()); diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs index 58bdd286a814..f01c99529aea 100644 --- a/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs @@ -1,7 +1,7 @@ use polars_core::utils::flatten; use polars_utils::hashing::{hash_to_partition, DirtyHash}; use polars_utils::idx_vec::IdxVec; -use polars_utils::iter::EnumerateIdxTrait; +use 
polars_utils::itertools::Itertools; use polars_utils::nulls::IsNull; use polars_utils::sync::SyncPtr; use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs index c5efc1f9550f..40d9deadc931 100644 --- a/crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs @@ -10,8 +10,8 @@ use super::*; pub(crate) fn create_hash_and_keys_threaded_vectorized( iters: Vec, - build_hasher: Option, -) -> (Vec>, RandomState) + build_hasher: Option, +) -> (Vec>, PlRandomState) where I: IntoIterator + Send, I::IntoIter: TrustedLen, diff --git a/crates/polars-ops/src/frame/join/iejoin/filtered_bit_array.rs b/crates/polars-ops/src/frame/join/iejoin/filtered_bit_array.rs new file mode 100644 index 000000000000..2c741a797f11 --- /dev/null +++ b/crates/polars-ops/src/frame/join/iejoin/filtered_bit_array.rs @@ -0,0 +1,49 @@ +use std::cmp::min; + +use arrow::bitmap::MutableBitmap; + +/// Bit array with a filter to speed up searching for set bits when sparse, +/// based on section 4.1 from Khayyat et al. 2015, +/// "Lightning Fast and Space Efficient Inequality Joins" +pub struct FilteredBitArray { + bit_array: MutableBitmap, + filter: MutableBitmap, +} + +impl FilteredBitArray { + const CHUNK_SIZE: usize = 1024; + + pub fn from_len_zeroed(len: usize) -> Self { + Self { + bit_array: MutableBitmap::from_len_zeroed(len), + filter: MutableBitmap::from_len_zeroed(len.div_ceil(Self::CHUNK_SIZE)), + } + } + + pub unsafe fn set_bit_unchecked(&mut self, index: usize) { + self.bit_array.set_unchecked(index, true); + self.filter.set_unchecked(index / Self::CHUNK_SIZE, true); + } + + pub unsafe fn on_set_bits_from(&self, start: usize, mut action: F) + where + F: FnMut(usize), + { + let start_chunk = start / Self::CHUNK_SIZE; + let mut chunk_offset = start % Self::CHUNK_SIZE; + for chunk_idx in start_chunk..self.filter.len() { + if self.filter.get_unchecked(chunk_idx) { + // There are some set bits in this chunk + let start = chunk_idx * Self::CHUNK_SIZE + chunk_offset; + let end = min((chunk_idx + 1) * Self::CHUNK_SIZE, self.bit_array.len()); + for bit_idx in start..end { + // SAFETY: `bit_idx` is always less than `self.bit_array.len()` + if self.bit_array.get_unchecked(bit_idx) { + action(bit_idx); + } + } + } + chunk_offset = 0; + } + } +} diff --git a/crates/polars-ops/src/frame/join/iejoin/l1_l2.rs b/crates/polars-ops/src/frame/join/iejoin/l1_l2.rs new file mode 100644 index 000000000000..67aa4cf6393b --- /dev/null +++ b/crates/polars-ops/src/frame/join/iejoin/l1_l2.rs @@ -0,0 +1,262 @@ +use polars_core::chunked_array::ChunkedArray; +use polars_core::datatypes::{IdxCa, PolarsNumericType}; +use polars_core::prelude::Series; +use polars_core::with_match_physical_numeric_polars_type; +use polars_error::PolarsResult; +use polars_utils::total_ord::TotalOrd; +use polars_utils::IdxSize; + +use super::*; + +/// Create a vector of L1 items from the array of LHS x values concatenated with RHS x values +/// and their ordering. +pub(super) fn build_l1_array( + ca: &ChunkedArray, + order: &IdxCa, + right_df_offset: IdxSize, +) -> PolarsResult>> +where + T: PolarsNumericType, +{ + assert_eq!(order.null_count(), 0); + assert_eq!(ca.chunks().len(), 1); + let arr = ca.downcast_get(0).unwrap(); + // Even if there are nulls, they will not be selected by order. 
+ let values = arr.values().as_slice(); + + let mut array: Vec> = Vec::with_capacity(ca.len()); + + for order_arr in order.downcast_iter() { + for index in order_arr.values().as_slice().iter().copied() { + debug_assert!(arr.get(index as usize).is_some()); + let value = unsafe { *values.get_unchecked(index as usize) }; + let row_index = if index < right_df_offset { + // Row from LHS + index as i64 + 1 + } else { + // Row from RHS + -((index - right_df_offset) as i64) - 1 + }; + array.push(L1Item { row_index, value }); + } + } + + Ok(array) +} + +pub(super) fn build_l2_array(s: &Series, order: &[IdxSize]) -> PolarsResult> { + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + build_l2_array_impl::<$T>(s.as_ref().as_ref(), order) + }) +} + +/// Create a vector of L2 items from the array of y values ordered according to the L1 order, +/// and their ordering. We don't need to store actual y values but only track whether we're at +/// the end of a run of equal values. +fn build_l2_array_impl(ca: &ChunkedArray, order: &[IdxSize]) -> PolarsResult> +where + T: PolarsNumericType, + T::Native: TotalOrd, +{ + assert_eq!(ca.chunks().len(), 1); + + let mut array = Vec::with_capacity(ca.len()); + let mut prev_index = 0; + let mut prev_value = T::Native::default(); + + let arr = ca.downcast_get(0).unwrap(); + // Even if there are nulls, they will not be selected by order. + let values = arr.values().as_slice(); + + for (i, l1_index) in order.iter().copied().enumerate() { + debug_assert!(arr.get(l1_index as usize).is_some()); + let value = unsafe { *values.get_unchecked(l1_index as usize) }; + if i > 0 { + array.push(L2Item { + l1_index: prev_index, + run_end: value.tot_ne(&prev_value), + }); + } + prev_index = l1_index; + prev_value = value; + } + if !order.is_empty() { + array.push(L2Item { + l1_index: prev_index, + run_end: true, + }); + } + Ok(array) +} + +/// Item in L1 array used in the IEJoin algorithm +#[derive(Clone, Copy, Debug)] +pub(super) struct L1Item { + /// 1 based index for entries from the LHS df, or -1 based index for entries from the RHS + pub(super) row_index: i64, + /// X value + pub(super) value: T, +} + +/// Item in L2 array used in the IEJoin algorithm +#[derive(Clone, Copy, Debug)] +pub(super) struct L2Item { + /// Corresponding index into the L1 array of + pub(super) l1_index: IdxSize, + /// Whether this is the end of a run of equal y values + pub(super) run_end: bool, +} + +pub(super) trait L1Array { + unsafe fn process_entry( + &self, + l1_index: usize, + bit_array: &mut FilteredBitArray, + op1: InequalityOperator, + left_row_ids: &mut Vec, + right_row_ids: &mut Vec, + ) -> i64; + + unsafe fn process_lhs_entry( + &self, + l1_index: usize, + bit_array: &FilteredBitArray, + op1: InequalityOperator, + left_row_ids: &mut Vec, + right_row_ids: &mut Vec, + ) -> i64; + + unsafe fn mark_visited(&self, index: usize, bit_array: &mut FilteredBitArray); +} + +/// Find the position in the L1 array where we should begin checking for matches, +/// given the index in L1 corresponding to the current position in L2. 
+unsafe fn find_search_start_index( + l1_array: &[L1Item], + index: usize, + operator: InequalityOperator, +) -> usize +where + T: NumericNative, + T: TotalOrd, +{ + let sub_l1 = l1_array.get_unchecked_release(index..); + let value = l1_array.get_unchecked_release(index).value; + + match operator { + InequalityOperator::Gt => { + sub_l1.partition_point_exponential(|a| a.value.tot_ge(&value)) + index + }, + InequalityOperator::Lt => { + sub_l1.partition_point_exponential(|a| a.value.tot_le(&value)) + index + }, + InequalityOperator::GtEq => { + sub_l1.partition_point_exponential(|a| value.tot_lt(&a.value)) + index + }, + InequalityOperator::LtEq => { + sub_l1.partition_point_exponential(|a| value.tot_gt(&a.value)) + index + }, + } +} + +fn find_matches_in_l1( + l1_array: &[L1Item], + l1_index: usize, + row_index: i64, + bit_array: &FilteredBitArray, + op1: InequalityOperator, + left_row_ids: &mut Vec, + right_row_ids: &mut Vec, +) -> i64 +where + T: NumericNative, + T: TotalOrd, +{ + debug_assert!(row_index > 0); + let mut match_count = 0; + + // This entry comes from the left hand side DataFrame. + // Find all following entries in L1 (meaning they satisfy the first operator) + // that have already been visited (so satisfy the second operator). + // Because we use a stable sort for l2, we know that we won't find any + // matches for duplicate y values when traversing forwards in l1. + let start_index = unsafe { find_search_start_index(l1_array, l1_index, op1) }; + unsafe { + bit_array.on_set_bits_from(start_index, |set_bit: usize| { + // SAFETY + // set bit is within bounds. + let right_row_index = l1_array.get_unchecked_release(set_bit).row_index; + debug_assert!(right_row_index < 0); + left_row_ids.push((row_index - 1) as IdxSize); + right_row_ids.push((-right_row_index) as IdxSize - 1); + match_count += 1; + }) + }; + + match_count +} + +impl L1Array for Vec> +where + T: NumericNative, +{ + unsafe fn process_entry( + &self, + l1_index: usize, + bit_array: &mut FilteredBitArray, + op1: InequalityOperator, + left_row_ids: &mut Vec, + right_row_ids: &mut Vec, + ) -> i64 { + let row_index = self.get_unchecked_release(l1_index).row_index; + let from_lhs = row_index > 0; + if from_lhs { + find_matches_in_l1( + self, + l1_index, + row_index, + bit_array, + op1, + left_row_ids, + right_row_ids, + ) + } else { + bit_array.set_bit_unchecked(l1_index); + 0 + } + } + + unsafe fn process_lhs_entry( + &self, + l1_index: usize, + bit_array: &FilteredBitArray, + op1: InequalityOperator, + left_row_ids: &mut Vec, + right_row_ids: &mut Vec, + ) -> i64 { + let row_index = self.get_unchecked_release(l1_index).row_index; + let from_lhs = row_index > 0; + if from_lhs { + find_matches_in_l1( + self, + l1_index, + row_index, + bit_array, + op1, + left_row_ids, + right_row_ids, + ) + } else { + 0 + } + } + + unsafe fn mark_visited(&self, index: usize, bit_array: &mut FilteredBitArray) { + let from_lhs = self.get_unchecked_release(index).row_index > 0; + // We only mark RHS entries as visited, + // so that we don't try to match LHS entries with other LHS entries. 
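
Throughout these L1 helpers a single signed `row_index` encodes both the side and the position of a row: left-hand rows are stored 1-based positive, right-hand rows negative, and `find_matches_in_l1` undoes that encoding when it pushes output ids. A small self-contained sketch of the encode/decode pair (the helper names are hypothetical):

```rust
type IdxSize = u32; // polars' default index type (u64 with the bigidx feature)

/// Encode a concatenated row id: LHS rows become 1-based positive values,
/// RHS rows become negative values, as in `build_l1_array`.
fn encode_row_index(index: IdxSize, right_df_offset: IdxSize) -> i64 {
    if index < right_df_offset {
        index as i64 + 1 // row from the left DataFrame
    } else {
        -((index - right_df_offset) as i64) - 1 // row from the right DataFrame
    }
}

/// Decode back to (is_lhs, 0-based row id), mirroring the pushes in
/// `find_matches_in_l1`: `row_index - 1` for LHS, `-row_index - 1` for RHS.
fn decode_row_index(row_index: i64) -> (bool, IdxSize) {
    if row_index > 0 {
        (true, (row_index - 1) as IdxSize)
    } else {
        (false, (-row_index - 1) as IdxSize)
    }
}

fn main() {
    // Left frame has 3 rows, so concatenated index 3 is the first right row.
    assert_eq!(encode_row_index(2, 3), 3); // LHS row 2
    assert_eq!(encode_row_index(4, 3), -2); // RHS row 1
    assert_eq!(decode_row_index(3), (true, 2));
    assert_eq!(decode_row_index(-2), (false, 1));
}
```

The sign also doubles as the `from_lhs` test used by `process_entry` and `mark_visited`.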
+ if !from_lhs { + bit_array.set_bit_unchecked(index); + } + } +} diff --git a/crates/polars-ops/src/frame/join/iejoin/mod.rs b/crates/polars-ops/src/frame/join/iejoin/mod.rs new file mode 100644 index 000000000000..d0698018c5bb --- /dev/null +++ b/crates/polars-ops/src/frame/join/iejoin/mod.rs @@ -0,0 +1,383 @@ +mod filtered_bit_array; +mod l1_l2; + +use filtered_bit_array::FilteredBitArray; +use l1_l2::*; +use polars_core::chunked_array::ChunkedArray; +use polars_core::datatypes::{IdxCa, NumericNative, PolarsNumericType}; +use polars_core::frame::DataFrame; +use polars_core::prelude::*; +use polars_core::utils::{_set_partition_size, split}; +use polars_core::{with_match_physical_numeric_polars_type, POOL}; +use polars_error::{polars_err, PolarsResult}; +use polars_utils::binary_search::ExponentialSearch; +use polars_utils::itertools::Itertools; +use polars_utils::slice::GetSaferUnchecked; +use polars_utils::total_ord::TotalEq; +use polars_utils::IdxSize; +use rayon::prelude::*; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use crate::frame::_finish_join; + +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum InequalityOperator { + #[default] + Lt, + LtEq, + Gt, + GtEq, +} + +impl InequalityOperator { + fn is_strict(&self) -> bool { + matches!(self, InequalityOperator::Gt | InequalityOperator::Lt) + } +} +#[derive(Clone, Debug, PartialEq, Eq, Default, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct IEJoinOptions { + pub operator1: InequalityOperator, + pub operator2: InequalityOperator, +} + +#[allow(clippy::too_many_arguments)] +fn ie_join_impl_t( + slice: Option<(i64, usize)>, + l1_order: IdxCa, + l2_order: &[IdxSize], + op1: InequalityOperator, + op2: InequalityOperator, + x: Series, + y_ordered_by_x: Series, + left_height: usize, +) -> PolarsResult<(Vec, Vec)> { + // Create a bit array with order corresponding to L1, + // denoting which entries have been visited while traversing L2. + let mut bit_array = FilteredBitArray::from_len_zeroed(l1_order.len()); + + let mut left_row_idx: Vec = vec![]; + let mut right_row_idx: Vec = vec![]; + + let slice_end = match slice { + Some((offset, len)) if offset >= 0 => Some(offset.saturating_add_unsigned(len as u64)), + _ => None, + }; + let mut match_count = 0; + + let ca: &ChunkedArray = x.as_ref().as_ref(); + let l1_array = build_l1_array(ca, &l1_order, left_height as IdxSize)?; + + if op2.is_strict() { + // For strict inequalities, we rely on using a stable sort of l2 so that + // p values only increase as we traverse a run of equal y values. + // To handle inclusive comparisons in x and duplicate x values we also need the + // sort of l1 to be stable, so that the left hand side entries come before the right + // hand side entries (as we mark visited entries from the right hand side). + for &p in l2_order { + match_count += unsafe { + l1_array.process_entry( + p as usize, + &mut bit_array, + op1, + &mut left_row_idx, + &mut right_row_idx, + ) + }; + + if slice_end.is_some_and(|end| match_count >= end) { + break; + } + } + } else { + let l2_array = build_l2_array(&y_ordered_by_x, l2_order)?; + + // For non-strict inequalities in l2, we need to track runs of equal y values and only + // check for matches after we reach the end of the run and have marked all rhs entries + // in the run as visited. 
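
For the non-strict case the comment above describes a two-phase traversal: first mark every entry of a run of equal y values as visited, then look for LHS matches within that same run, so rows with equal y values can still match under `<=`/`>=`. A simplified standalone rendering of the loop that follows (plain callbacks stand in for `mark_visited` and `process_lhs_entry`):

```rust
struct L2Entry {
    l1_index: usize,
    run_end: bool, // true on the last entry of a run of equal y values
}

/// Two-phase run traversal: mark the whole run, then emit matches for it.
/// Returns the total number of matches reported by `match_lhs`.
fn traverse_runs(
    l2: &[L2Entry],
    mut mark_visited: impl FnMut(usize),
    mut match_lhs: impl FnMut(usize) -> usize,
) -> usize {
    let mut total_matches = 0;
    let mut run_start = 0;
    for (i, entry) in l2.iter().enumerate() {
        // Phase 1: mark (the real code only marks right-hand rows).
        mark_visited(entry.l1_index);
        if entry.run_end {
            // Phase 2: the run is fully marked, now match its LHS rows.
            for e in &l2[run_start..=i] {
                total_matches += match_lhs(e.l1_index);
            }
            run_start = i + 1;
        }
    }
    total_matches
}

fn main() {
    // One run of two equal y values followed by a singleton run.
    let l2 = [
        L2Entry { l1_index: 0, run_end: false },
        L2Entry { l1_index: 1, run_end: true },
        L2Entry { l1_index: 2, run_end: true },
    ];
    let mut marked = vec![];
    let matches = traverse_runs(&l2, |i| marked.push(i), |_| 1);
    assert_eq!(marked, vec![0, 1, 2]);
    assert_eq!(matches, 3);
}
```

The strict branch earlier avoids this bookkeeping by leaning on the stable L2 sort, as its own comment explains.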
+ let mut run_start = 0; + + for i in 0..l2_array.len() { + // Elide bound checks + unsafe { + let item = l2_array.get_unchecked_release(i); + let p = item.l1_index; + l1_array.mark_visited(p as usize, &mut bit_array); + + if item.run_end { + for l2_item in l2_array.get_unchecked_release(run_start..i + 1) { + let p = l2_item.l1_index; + match_count += l1_array.process_lhs_entry( + p as usize, + &bit_array, + op1, + &mut left_row_idx, + &mut right_row_idx, + ); + } + + run_start = i + 1; + + if slice_end.is_some_and(|end| match_count >= end) { + break; + } + } + } + } + } + Ok((left_row_idx, right_row_idx)) +} + +pub(super) fn iejoin_par( + left: &DataFrame, + right: &DataFrame, + selected_left: Vec, + selected_right: Vec, + options: &IEJoinOptions, + suffix: Option, + slice: Option<(i64, usize)>, +) -> PolarsResult { + let l1_descending = matches!( + options.operator1, + InequalityOperator::Gt | InequalityOperator::GtEq + ); + + let l1_sort_options = SortOptions::default() + .with_maintain_order(true) + .with_nulls_last(false) + .with_order_descending(l1_descending); + + let sl = &selected_left[0]; + let l1_s_l = sl + .arg_sort(l1_sort_options) + .slice(sl.null_count() as i64, sl.len() - sl.null_count()); + + let sr = &selected_right[0]; + let l1_s_r = sr + .arg_sort(l1_sort_options) + .slice(sr.null_count() as i64, sr.len() - sr.null_count()); + + // Because we do a cartesian product, the number of partitions is squared. + // We take the sqrt, but we don't expect every partition to produce results and work can be + // imbalanced, so we multiply the number of partitions by 2, which leads to 2^2= 4 + let n_partitions = (_set_partition_size() as f32).sqrt() as usize * 2; + let splitted_a = split(&l1_s_l, n_partitions); + let splitted_b = split(&l1_s_r, n_partitions); + + let cartesian_prod = splitted_a + .iter() + .flat_map(|l| splitted_b.iter().map(move |r| (l, r))) + .collect::>(); + + let iter = cartesian_prod.par_iter().map(|(l_l1_idx, r_l1_idx)| { + if l_l1_idx.is_empty() || r_l1_idx.is_empty() { + return Ok(None); + } + fn get_extrema<'a>( + l1_idx: &'a IdxCa, + s: &'a Series, + ) -> Option<(AnyValue<'a>, AnyValue<'a>)> { + let first = l1_idx.first()?; + let last = l1_idx.last()?; + + let start = s.get(first as usize).unwrap(); + let end = s.get(last as usize).unwrap(); + + Some(if start < end { + (start, end) + } else { + (end, start) + }) + } + let Some((min_l, max_l)) = get_extrema(l_l1_idx, sl) else { + return Ok(None); + }; + let Some((min_r, max_r)) = get_extrema(r_l1_idx, sr) else { + return Ok(None); + }; + + let include_block = match options.operator1 { + InequalityOperator::Lt => min_l < max_r, + InequalityOperator::LtEq => min_l <= max_r, + InequalityOperator::Gt => max_l > min_r, + InequalityOperator::GtEq => max_l >= min_r, + }; + + if include_block { + let (l, r) = unsafe { + ( + selected_left + .iter() + .map(|s| s.take_unchecked(l_l1_idx)) + .collect_vec(), + selected_right + .iter() + .map(|s| s.take_unchecked(r_l1_idx)) + .collect_vec(), + ) + }; + + // Compute the row indexes + let (idx_l, idx_r) = iejoin_tuples(l, r, options, None)?; + + if idx_l.is_empty() { + return Ok(None); + } + + // These are row indexes in the slices we have given, so we use those to gather in the + // original l1 offset arrays. This gives us indexes in the original tables. 
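
`iejoin_par` above only descends into a left/right block pair when the extrema of the first key admit at least one satisfying pair; otherwise the whole block is pruned. The check in isolation (a sketch, with a stand-in operator enum):

```rust
#[derive(Clone, Copy)]
enum Op {
    Lt,
    LtEq,
    Gt,
    GtEq,
}

/// Mirrors the `include_block` test: given the (min, max) of the first join
/// key on each side of a candidate block, can any row pair satisfy `op1`?
fn block_can_match<T: PartialOrd>(op1: Op, min_l: T, max_l: T, min_r: T, max_r: T) -> bool {
    match op1 {
        Op::Lt => min_l < max_r,
        Op::LtEq => min_l <= max_r,
        Op::Gt => max_l > min_r,
        Op::GtEq => max_l >= min_r,
    }
}

fn main() {
    // Left keys span 5..=9, right keys span 1..=4: `x_left < x_right`
    // is impossible, so this block pair would be skipped entirely.
    assert!(!block_can_match(Op::Lt, 5, 9, 1, 4));
    assert!(block_can_match(Op::Lt, 5, 9, 1, 6));
}
```

Because the candidate blocks are the cartesian product of the two per-side partitionings, the number of partitions per side is kept near the square root of the thread count, as the comment in the diff notes.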
+ unsafe { + Ok(Some(( + l_l1_idx.take_unchecked(&idx_l), + r_l1_idx.take_unchecked(&idx_r), + ))) + } + } else { + Ok(None) + } + }); + + let row_indices = POOL.install(|| iter.collect::>>())?; + + let mut left_idx = IdxCa::default(); + let mut right_idx = IdxCa::default(); + for (l, r) in row_indices.into_iter().flatten() { + left_idx.append(&l)?; + right_idx.append(&r)?; + } + if let Some((offset, end)) = slice { + left_idx = left_idx.slice(offset, end); + right_idx = right_idx.slice(offset, end); + } + + unsafe { materialize_join(left, right, &left_idx, &right_idx, suffix) } +} + +pub(super) fn iejoin( + left: &DataFrame, + right: &DataFrame, + selected_left: Vec, + selected_right: Vec, + options: &IEJoinOptions, + suffix: Option, + slice: Option<(i64, usize)>, +) -> PolarsResult { + let (left_row_idx, right_row_idx) = + iejoin_tuples(selected_left, selected_right, options, slice)?; + unsafe { materialize_join(left, right, &left_row_idx, &right_row_idx, suffix) } +} + +unsafe fn materialize_join( + left: &DataFrame, + right: &DataFrame, + left_row_idx: &IdxCa, + right_row_idx: &IdxCa, + suffix: Option, +) -> PolarsResult { + let (join_left, join_right) = { + POOL.join( + || left.take_unchecked(left_row_idx), + || right.take_unchecked(right_row_idx), + ) + }; + + _finish_join(join_left, join_right, suffix) +} + +/// Inequality join. Matches rows between two DataFrames using two inequality operators +/// (one of [<, <=, >, >=]). +/// Based on Khayyat et al. 2015, "Lightning Fast and Space Efficient Inequality Joins" +/// and extended to work with duplicate values. +fn iejoin_tuples( + selected_left: Vec, + selected_right: Vec, + options: &IEJoinOptions, + slice: Option<(i64, usize)>, +) -> PolarsResult<(IdxCa, IdxCa)> { + if selected_left.len() != 2 { + return Err( + polars_err!(ComputeError: "IEJoin requires exactly two expressions from the left DataFrame"), + ); + }; + if selected_right.len() != 2 { + return Err( + polars_err!(ComputeError: "IEJoin requires exactly two expressions from the right DataFrame"), + ); + }; + + let op1 = options.operator1; + let op2 = options.operator2; + + // Determine the sort order based on the comparison operators used. + // We want to sort L1 so that "x[i] op1 x[j]" is true for j > i, + // and L2 so that "y[i] op2 y[j]" is true for j < i + // (except in the case of duplicates and strict inequalities). + // Note that the algorithms published in Khayyat et al. have incorrect logic for + // determining whether to sort descending. + let l1_descending = matches!(op1, InequalityOperator::Gt | InequalityOperator::GtEq); + let l2_descending = matches!(op2, InequalityOperator::Lt | InequalityOperator::LtEq); + + let mut x = selected_left[0].to_physical_repr().into_owned(); + let left_height = x.len(); + + x.extend(&selected_right[0].to_physical_repr())?; + // Rechunk because we will gather. + let x = x.rechunk(); + + let mut y = selected_left[1].to_physical_repr().into_owned(); + y.extend(&selected_right[1].to_physical_repr())?; + // Rechunk because we will gather. 
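
The sort directions in `iejoin_tuples` are derived purely from the operators: L1 must be ordered so that `x[i] op1 x[j]` holds for later `j`, hence descending for `>`/`>=` and ascending otherwise (and symmetrically for L2). A tiny sketch of that selection and the invariant it buys (the operator enum is a stand-in):

```rust
#[derive(Clone, Copy)]
enum Op {
    Lt,
    LtEq,
    Gt,
    GtEq,
}

/// Descending L1 order for > / >=, ascending for < / <=.
fn l1_descending(op1: Op) -> bool {
    matches!(op1, Op::Gt | Op::GtEq)
}

fn main() {
    let x = [3_i64, 1, 4, 1, 5];
    let op1 = Op::Gt;
    let mut order: Vec<usize> = (0..x.len()).collect();
    if l1_descending(op1) {
        order.sort_by(|&a, &b| x[b].cmp(&x[a])); // descending
    } else {
        order.sort_by(|&a, &b| x[a].cmp(&x[b])); // ascending
    }
    let sorted: Vec<i64> = order.iter().map(|&i| x[i]).collect();
    // With `>`, every strictly later, strictly smaller entry satisfies op1.
    assert_eq!(sorted, vec![5, 4, 3, 1, 1]);
}
```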
+ let y = y.rechunk(); + + let l1_sort_options = SortOptions::default() + .with_maintain_order(true) + .with_nulls_last(false) + .with_order_descending(l1_descending); + // Get ordering of x, skipping any null entries as these cannot be matches + let l1_order = x + .arg_sort(l1_sort_options) + .slice(x.null_count() as i64, x.len() - x.null_count()); + + let y_ordered_by_x = unsafe { y.take_unchecked(&l1_order) }; + let l2_sort_options = SortOptions::default() + .with_maintain_order(true) + .with_nulls_last(false) + .with_order_descending(l2_descending); + // Get the indexes into l1, ordered by y values. + // l2_order is the same as "p" from Khayyat et al. + let l2_order = y_ordered_by_x + .arg_sort(l2_sort_options) + .slice( + y_ordered_by_x.null_count() as i64, + y_ordered_by_x.len() - y_ordered_by_x.null_count(), + ) + .rechunk(); + let l2_order = l2_order.downcast_get(0).unwrap().values().as_slice(); + + let (left_row_idx, right_row_idx) = with_match_physical_numeric_polars_type!(x.dtype(), |$T| { + ie_join_impl_t::<$T>( + slice, + l1_order, + l2_order, + op1, + op2, + x, + y_ordered_by_x, + left_height + ) + })?; + + debug_assert_eq!(left_row_idx.len(), right_row_idx.len()); + let left_row_idx = IdxCa::from_vec("".into(), left_row_idx); + let right_row_idx = IdxCa::from_vec("".into(), right_row_idx); + let (left_row_idx, right_row_idx) = match slice { + None => (left_row_idx, right_row_idx), + Some((offset, len)) => ( + left_row_idx.slice(offset, len), + right_row_idx.slice(offset, len), + ), + }; + Ok((left_row_idx, right_row_idx)) +} diff --git a/crates/polars-ops/src/frame/join/merge_sorted.rs b/crates/polars-ops/src/frame/join/merge_sorted.rs index d368ef5f5159..a9f02c2904cd 100644 --- a/crates/polars-ops/src/frame/join/merge_sorted.rs +++ b/crates/polars-ops/src/frame/join/merge_sorted.rs @@ -38,7 +38,7 @@ pub fn _merge_sorted_dfs( let out = merge_series(&lhs_phys, &rhs_phys, &merge_indicator)?; let mut out = out.cast(lhs.dtype()).unwrap(); - out.rename(lhs.name()); + out.rename(lhs.name().clone()); Ok(out) }) .collect::>()?; @@ -81,7 +81,7 @@ fn merge_series(lhs: &Series, rhs: &Series, merge_indicator: &[bool]) -> PolarsR .zip(rhs.fields_as_series()) .map(|(lhs, rhs)| merge_series(lhs, &rhs, merge_indicator)) .collect::>>()?; - StructChunked::from_series("", &new_fields) + StructChunked::from_series(PlSmallStr::EMPTY, &new_fields) .unwrap() .into_series() }, diff --git a/crates/polars-ops/src/frame/join/mod.rs b/crates/polars-ops/src/frame/join/mod.rs index 229e13457f81..433bffd232dd 100644 --- a/crates/polars-ops/src/frame/join/mod.rs +++ b/crates/polars-ops/src/frame/join/mod.rs @@ -7,6 +7,7 @@ mod cross_join; mod dispatch_left_right; mod general; mod hash_join; +mod iejoin; #[cfg(feature = "merge_sorted")] mod merge_sorted; @@ -14,7 +15,6 @@ use std::borrow::Cow; use std::fmt::{Debug, Display, Formatter}; use std::hash::Hash; -use ahash::RandomState; pub use args::*; use arrow::trusted_len::TrustedLen; #[cfg(feature = "asof_join")] @@ -29,6 +29,7 @@ use general::create_chunked_index_mapping; pub use general::{_coalesce_full_join, _finish_join, _join_suffix_name}; pub use hash_join::*; use hashbrown::hash_map::{Entry, RawEntryMut}; +pub use iejoin::{IEJoinOptions, InequalityOperator}; #[cfg(feature = "merge_sorted")] pub use merge_sorted::_merge_sorted_dfs; use polars_core::hashing::_HASHMAP_INIT_SIZE; @@ -82,17 +83,13 @@ pub trait DataFrameJoinOps: IntoDf { /// | Pear | 12 | 115 | /// +--------+----------------------+---------------------+ /// ``` - fn join( + fn join( &self, 
other: &DataFrame, - left_on: I, - right_on: I, + left_on: impl IntoIterator>, + right_on: impl IntoIterator>, args: JoinArgs, - ) -> PolarsResult - where - I: IntoIterator, - S: AsRef, - { + ) -> PolarsResult { let df_left = self.to_df(); let selected_left = df_left.select_series(left_on)?; let selected_right = other.select_series(right_on)?; @@ -115,7 +112,7 @@ pub trait DataFrameJoinOps: IntoDf { #[cfg(feature = "cross_join")] if let JoinType::Cross = args.how { - return left_df.cross_join(other, args.suffix.as_deref(), args.slice); + return left_df.cross_join(other, args.suffix.clone(), args.slice); } // Clear literals if a frame is empty. Otherwise we could get an oob @@ -196,17 +193,36 @@ pub trait DataFrameJoinOps: IntoDf { Err(_) => { let (ca_left, ca_right) = make_categoricals_compatible(l.categorical()?, r.categorical()?)?; - *l = ca_left.into_series().with_name(l.name()); - *r = ca_right.into_series().with_name(r.name()); + *l = ca_left.into_series().with_name(l.name().clone()); + *r = ca_right.into_series().with_name(r.name().clone()); }, } } + if let JoinType::IEJoin(options) = args.how { + let func = if POOL.current_num_threads() > 1 && !left_df.is_empty() && !other.is_empty() + { + iejoin::iejoin_par + } else { + iejoin::iejoin + }; + return func( + left_df, + other, + selected_left, + selected_right, + &options, + args.suffix, + args.slice, + ); + } + // Single keys. if selected_left.len() == 1 { let s_left = &selected_left[0]; let s_right = &selected_right[0]; - let drop_names: Option<&[&str]> = if should_coalesce { None } else { Some(&[]) }; + let drop_names: Option> = + if should_coalesce { None } else { Some(vec![]) }; return match args.how { JoinType::Inner => left_df ._inner_join_from_series(other, s_left, s_right, args, _verbose, drop_names), @@ -255,7 +271,7 @@ pub trait DataFrameJoinOps: IntoDf { right_by, options.strategy, options.tolerance, - args.suffix.as_deref(), + args.suffix.clone(), args.slice, should_coalesce, ), @@ -273,6 +289,9 @@ pub trait DataFrameJoinOps: IntoDf { panic!("expected by arguments on both sides") }, }, + JoinType::IEJoin(_) => { + unreachable!() + }, JoinType::Cross => { unreachable!() }, @@ -283,9 +302,12 @@ pub trait DataFrameJoinOps: IntoDf { let rhs_keys = prepare_keys_multiple(&selected_right, args.join_nulls)?.into_series(); let drop_names = if should_coalesce { - Some(selected_right.iter().map(|s| s.name()).collect::>()) + selected_right + .iter() + .map(|s| s.name().clone()) + .collect::>() } else { - Some(vec![]) + vec![] }; // Multiple keys. 
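
With `JoinType::IEJoin(options)` now dispatching into the module above, the semantics being implemented are simply: keep every row pair whose two key pairs satisfy `operator1` and `operator2`. A brute-force O(n·m) reference of those semantics over plain slices, useful as a mental model or test oracle (this is not the PR's API):

```rust
#[derive(Clone, Copy)]
enum Op {
    Lt,
    LtEq,
    Gt,
    GtEq,
}

fn cmp(op: Op, l: i64, r: i64) -> bool {
    match op {
        Op::Lt => l < r,
        Op::LtEq => l <= r,
        Op::Gt => l > r,
        Op::GtEq => l >= r,
    }
}

/// Brute-force inequality join: returns (left_row, right_row) index pairs
/// where `x_l op1 x_r` and `y_l op2 y_r` both hold.
fn iejoin_reference(
    x_l: &[i64], y_l: &[i64],
    x_r: &[i64], y_r: &[i64],
    op1: Op, op2: Op,
) -> Vec<(usize, usize)> {
    let mut out = Vec::new();
    for i in 0..x_l.len() {
        for j in 0..x_r.len() {
            if cmp(op1, x_l[i], x_r[j]) && cmp(op2, y_l[i], y_r[j]) {
                out.push((i, j));
            }
        }
    }
    out
}

fn main() {
    // Two small frames, joined on `x_l < x_r` and `y_l > y_r`.
    let pairs = iejoin_reference(&[1, 4], &[5, 9], &[2, 6], &[3, 7], Op::Lt, Op::Gt);
    assert_eq!(pairs, vec![(0, 0), (1, 1)]);
}
```

The point of the L1/L2 machinery above is to produce the same pairs without the nested loop.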
@@ -294,11 +316,17 @@ pub trait DataFrameJoinOps: IntoDf { JoinType::AsOf(_) => polars_bail!( ComputeError: "asof join not supported for join on multiple keys" ), + JoinType::IEJoin(_) => { + unreachable!() + }, JoinType::Cross => { unreachable!() }, JoinType::Full => { - let names_left = selected_left.iter().map(|s| s.name()).collect::>(); + let names_left = selected_left + .iter() + .map(|s| s.name().clone()) + .collect::>(); args.coalesce = JoinCoalesce::KeepColumns; let suffix = args.suffix.clone(); let out = left_df._full_join_from_series(other, &lhs_keys, &rhs_keys, args); @@ -306,9 +334,9 @@ pub trait DataFrameJoinOps: IntoDf { if should_coalesce { Ok(_coalesce_full_join( out?, - &names_left, - drop_names.as_ref().unwrap(), - suffix.as_deref(), + names_left.as_slice(), + drop_names.as_slice(), + suffix.clone(), left_df, )) } else { @@ -321,7 +349,7 @@ pub trait DataFrameJoinOps: IntoDf { &rhs_keys, args, _verbose, - drop_names.as_deref(), + Some(drop_names), ), JoinType::Left => dispatch_left_right::left_join_from_series( left_df.clone(), @@ -330,7 +358,7 @@ pub trait DataFrameJoinOps: IntoDf { &rhs_keys, args, _verbose, - drop_names.as_deref(), + Some(drop_names), ), JoinType::Right => dispatch_left_right::right_join_from_series( left_df, @@ -339,7 +367,7 @@ pub trait DataFrameJoinOps: IntoDf { &rhs_keys, args, _verbose, - drop_names.as_deref(), + Some(drop_names), ), #[cfg(feature = "semi_anti_join")] JoinType::Anti | JoinType::Semi => self._join_impl( @@ -364,16 +392,12 @@ pub trait DataFrameJoinOps: IntoDf { /// left.inner_join(right, ["join_column_left"], ["join_column_right"]) /// } /// ``` - fn inner_join( + fn inner_join( &self, other: &DataFrame, - left_on: I, - right_on: I, - ) -> PolarsResult - where - I: IntoIterator, - S: AsRef, - { + left_on: impl IntoIterator>, + right_on: impl IntoIterator>, + ) -> PolarsResult { self.join(other, left_on, right_on, JoinArgs::new(JoinType::Inner)) } @@ -412,11 +436,12 @@ pub trait DataFrameJoinOps: IntoDf { /// | 100 | null | /// +-----------------+--------+ /// ``` - fn left_join(&self, other: &DataFrame, left_on: I, right_on: I) -> PolarsResult - where - I: IntoIterator, - S: AsRef, - { + fn left_join( + &self, + other: &DataFrame, + left_on: impl IntoIterator>, + right_on: impl IntoIterator>, + ) -> PolarsResult { self.join(other, left_on, right_on, JoinArgs::new(JoinType::Left)) } @@ -430,11 +455,12 @@ pub trait DataFrameJoinOps: IntoDf { /// left.full_join(right, ["join_column_left"], ["join_column_right"]) /// } /// ``` - fn full_join(&self, other: &DataFrame, left_on: I, right_on: I) -> PolarsResult - where - I: IntoIterator, - S: AsRef, - { + fn full_join( + &self, + other: &DataFrame, + left_on: impl IntoIterator>, + right_on: impl IntoIterator>, + ) -> PolarsResult { self.join(other, left_on, right_on, JoinArgs::new(JoinType::Full)) } } @@ -447,7 +473,7 @@ trait DataFrameJoinOpsPrivate: IntoDf { s_right: &Series, args: JoinArgs, verbose: bool, - drop_names: Option<&[&str]>, + drop_names: Option>, ) -> PolarsResult { let left_df = self.to_df(); #[cfg(feature = "dtype-categorical")] @@ -475,7 +501,7 @@ trait DataFrameJoinOpsPrivate: IntoDf { ._take_unchecked_slice(join_tuples_right, true) }, ); - _finish_join(df_left, df_right, args.suffix.as_deref()) + _finish_join(df_left, df_right, args.suffix.clone()) } } diff --git a/crates/polars-ops/src/frame/mod.rs b/crates/polars-ops/src/frame/mod.rs index 93b2af3dd272..5691919c8861 100644 --- a/crates/polars-ops/src/frame/mod.rs +++ b/crates/polars-ops/src/frame/mod.rs @@ -96,13 
+96,16 @@ pub trait DataFrameOps: IntoDf { ) -> PolarsResult { let df = self.to_df(); - let set: PlHashSet<&str> = - PlHashSet::from_iter(columns.unwrap_or_else(|| df.get_column_names())); + let set: PlHashSet<&str> = if let Some(columns) = columns { + PlHashSet::from_iter(columns) + } else { + PlHashSet::from_iter(df.iter().map(|s| s.name().as_str())) + }; let cols = POOL.install(|| { df.get_columns() .par_iter() - .map(|s| match set.contains(s.name()) { + .map(|s| match set.contains(s.name().as_str()) { true => s.to_dummies(separator, drop_first), false => Ok(s.clone().into_frame()), }) diff --git a/crates/polars-ops/src/frame/pivot/mod.rs b/crates/polars-ops/src/frame/pivot/mod.rs index 7fea3564532e..d909b580f87b 100644 --- a/crates/polars-ops/src/frame/pivot/mod.rs +++ b/crates/polars-ops/src/frame/pivot/mod.rs @@ -1,4 +1,5 @@ mod positioning; +mod unpivot; use std::borrow::Cow; @@ -7,6 +8,8 @@ use polars_core::frame::group_by::expr::PhysicalAggExpr; use polars_core::prelude::*; use polars_core::utils::_split_offsets; use polars_core::{downcast_as_macro_arg_physical, POOL}; +use polars_utils::format_pl_smallstr; +pub use unpivot::UnpivotDF; const HASHMAP_INIT_SIZE: usize = 512; @@ -95,14 +98,11 @@ where I0: IntoIterator, I1: IntoIterator, I2: IntoIterator, - S0: AsRef, - S1: AsRef, - S2: AsRef, + S0: Into, + S1: Into, + S2: Into, { - let on = on - .into_iter() - .map(|s| s.as_ref().to_string()) - .collect::>(); + let on = on.into_iter().map(Into::into).collect::>(); let (index, values) = assign_remaining_columns(pivot_df, &on, index, values)?; pivot_impl( pivot_df, @@ -134,20 +134,17 @@ where I0: IntoIterator, I1: IntoIterator, I2: IntoIterator, - S0: AsRef, - S1: AsRef, - S2: AsRef, + S0: Into, + S1: Into, + S2: Into, { - let on = on - .into_iter() - .map(|s| s.as_ref().to_string()) - .collect::>(); + let on = on.into_iter().map(Into::into).collect::>(); let (index, values) = assign_remaining_columns(pivot_df, &on, index, values)?; pivot_impl( pivot_df, - &on, - &index, - &values, + on.as_slice(), + index.as_slice(), + values.as_slice(), agg_fn, sort_columns, true, @@ -162,39 +159,39 @@ where /// - At least one of `index` and `values` must be non-null. 
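
`assign_remaining_columns`, defined next, implements the defaulting rule spelled out in its doc comment: whichever of `index`/`values` is omitted is filled with every column not already claimed by `on` or the other argument. A standalone sketch of that rule over plain strings (the function name and error text are illustrative):

```rust
/// Given all column names and the `on` keys, default the missing side of
/// (`index`, `values`) to "every column not already claimed".
fn assign_remaining(
    all_columns: &[&str],
    on: &[&str],
    index: Option<Vec<&str>>,
    values: Option<Vec<&str>>,
) -> Result<(Vec<String>, Vec<String>), String> {
    let to_owned = |v: Vec<&str>| v.into_iter().map(String::from).collect::<Vec<_>>();
    let remaining = |claimed: &[&str]| {
        all_columns
            .iter()
            .filter(|c| !claimed.contains(c) && !on.contains(c))
            .map(|c| c.to_string())
            .collect::<Vec<_>>()
    };
    match (index, values) {
        (Some(i), Some(v)) => Ok((to_owned(i), to_owned(v))),
        (Some(i), None) => {
            let v = remaining(&i);
            Ok((to_owned(i), v))
        },
        (None, Some(v)) => {
            let i = remaining(&v);
            Ok((i, to_owned(v)))
        },
        (None, None) => Err("at least one of `index` and `values` must be given".into()),
    }
}

fn main() {
    let (index, values) =
        assign_remaining(&["a", "b", "c", "d"], &["a"], Some(vec!["b"]), None).unwrap();
    assert_eq!(index, vec!["b"]);
    assert_eq!(values, vec!["c", "d"]);
}
```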
fn assign_remaining_columns( df: &DataFrame, - on: &[String], + on: &[PlSmallStr], index: Option, values: Option, -) -> PolarsResult<(Vec, Vec)> +) -> PolarsResult<(Vec, Vec)> where I1: IntoIterator, I2: IntoIterator, - S1: AsRef, - S2: AsRef, + S1: Into, + S2: Into, { match (index, values) { (Some(index), Some(values)) => { - let index = index.into_iter().map(|s| s.as_ref().to_string()).collect(); - let values = values.into_iter().map(|s| s.as_ref().to_string()).collect(); + let index = index.into_iter().map(Into::into).collect(); + let values = values.into_iter().map(Into::into).collect(); Ok((index, values)) }, (Some(index), None) => { - let index: Vec = index.into_iter().map(|s| s.as_ref().to_string()).collect(); + let index: Vec = index.into_iter().map(Into::into).collect(); let values = df .get_column_names() .into_iter() - .map(|s| s.to_string()) .filter(|c| !(index.contains(c) | on.contains(c))) + .cloned() .collect(); Ok((index, values)) }, (None, Some(values)) => { - let values: Vec = values.into_iter().map(|s| s.as_ref().to_string()).collect(); + let values: Vec = values.into_iter().map(Into::into).collect(); let index = df .get_column_names() .into_iter() - .map(|s| s.to_string()) .filter(|c| !(values.contains(c) | on.contains(c))) + .cloned() .collect(); Ok((index, values)) }, @@ -208,12 +205,12 @@ where fn pivot_impl( pivot_df: &DataFrame, // keys of the first group_by operation - on: &[String], + on: &[PlSmallStr], // these columns will be aggregated in the nested group_by - index: &[String], + index: &[PlSmallStr], // these columns will be used for a nested group_by // the rows of this nested group_by will be pivoted as header column values - values: &[String], + values: &[PlSmallStr], // aggregation function agg_fn: Option, sort_columns: bool, @@ -228,14 +225,14 @@ fn pivot_impl( }; if on.len() > 1 { let schema = Arc::new(pivot_df.schema()); - let binding = pivot_df.select_with_schema(on, &schema)?; + let binding = pivot_df.select_with_schema(on.iter().cloned(), &schema)?; let fields = binding.get_columns(); - let column = format!("{{\"{}\"}}", on.join("\",\"")); + let column = format_pl_smallstr!("{{\"{}\"}}", on.join("\",\"")); if schema.contains(column.as_str()) { polars_bail!(ComputeError: "cannot use column name {column} that \ already exists in the DataFrame. 
Please rename it prior to calling `pivot`.") } - let columns_struct = StructChunked::from_series(&column, fields) + let columns_struct = StructChunked::from_series(column.clone(), fields) .unwrap() .into_series(); let mut binding = pivot_df.clone(); @@ -264,9 +261,9 @@ fn pivot_impl( fn pivot_impl_single_column( pivot_df: &DataFrame, - index: &[String], - column: &str, - values: &[String], + index: &[PlSmallStr], + column: &PlSmallStr, + values: &[PlSmallStr], agg_fn: Option, sort_columns: bool, separator: Option<&str>, @@ -276,7 +273,7 @@ fn pivot_impl_single_column( let mut count = 0; let out: PolarsResult<()> = POOL.install(|| { let mut group_by = index.to_vec(); - group_by.push(column.to_string()); + group_by.push(column.clone()); let groups = pivot_df.group_by_stable(group_by)?.take_groups(); @@ -294,9 +291,13 @@ fn pivot_impl_single_column( let value_agg = unsafe { match &agg_fn { None => match value_col.len() > groups.len() { - true => polars_bail!(ComputeError: "found multiple elements in the same group, please specify an aggregation function"), + true => polars_bail!( + ComputeError: + "found multiple elements in the same group, \ + please specify an aggregation function" + ), false => value_col.agg_first(&groups), - } + }, Some(agg_fn) => match agg_fn { Sum => value_col.agg_sum(&groups), Min => value_col.agg_min(&groups), @@ -307,14 +308,14 @@ fn pivot_impl_single_column( Median => value_col.agg_median(&groups), Count => groups.group_count().into_series(), Expr(ref expr) => { - let name = expr.root_name()?; + let name = expr.root_name()?.clone(); let mut value_col = value_col.clone(); value_col.rename(name); let tmp_df = value_col.into_frame(); let mut aggregated = expr.evaluate(&tmp_df, &groups)?; - aggregated.rename(value_col_name); + aggregated.rename(value_col_name.clone()); aggregated - } + }, }, } }; diff --git a/crates/polars-ops/src/frame/pivot/positioning.rs b/crates/polars-ops/src/frame/pivot/positioning.rs index ec6f6eec4792..51761df873b5 100644 --- a/crates/polars-ops/src/frame/pivot/positioning.rs +++ b/crates/polars-ops/src/frame/pivot/positioning.rs @@ -73,7 +73,9 @@ pub(super) fn position_aggregates( .map(|(i, opt_name)| { let offset = i * n_rows; let avs = &buf[offset..offset + n_rows]; - let name = opt_name.unwrap_or("null"); + let name = opt_name + .map(PlSmallStr::from_str) + .unwrap_or_else(|| PlSmallStr::from_static("null")); let out = match &phys_type { #[cfg(feature = "dtype-struct")] DataType::Struct(_) => { @@ -166,7 +168,9 @@ where .map(|(i, opt_name)| { let offset = i * n_rows; let opt_values = &buf[offset..offset + n_rows]; - let name = opt_name.unwrap_or("null"); + let name = opt_name + .map(PlSmallStr::from_str) + .unwrap_or_else(|| PlSmallStr::from_static("null")); let out = ChunkedArray::::from_slice_options(name, opt_values).into_series(); unsafe { out.cast_unchecked(logical_type).unwrap() } }) @@ -293,7 +297,7 @@ pub(super) fn compute_col_idx( } fn compute_row_index<'a, T>( - index: &[String], + index: &[PlSmallStr], index_agg_physical: &'a ChunkedArray, count: usize, logical_type: &DataType, @@ -331,7 +335,7 @@ where .map(|(k, _)| Option::>::peel_total_ord(k)) .collect::>() .into_series(); - s.rename(&index[0]); + s.rename(index[0].clone()); let s = restore_logical_type(&s, logical_type); Some(vec![s]) }, @@ -342,7 +346,7 @@ where } fn compute_row_index_struct( - index: &[String], + index: &[PlSmallStr], index_agg: &Series, index_agg_physical: &BinaryOffsetChunked, count: usize, @@ -377,7 +381,7 @@ fn compute_row_index_struct( // SAFETY: 
`unique_indices` is filled with elements between // 0 and `index_agg.len() - 1`. let mut s = unsafe { index_agg.take_slice_unchecked(&unique_indices) }; - s.rename(&index[0]); + s.rename(index[0].clone()); Some(vec![s]) }, _ => None, @@ -389,7 +393,7 @@ fn compute_row_index_struct( // TODO! Also create a specialized version for numerics. pub(super) fn compute_row_idx( pivot_df: &DataFrame, - index: &[String], + index: &[PlSmallStr], groups: &GroupsProxy, count: usize, ) -> PolarsResult<(Vec, usize, Option>)> { @@ -452,7 +456,7 @@ pub(super) fn compute_row_idx( let row_index = match count { 0 => { let s = Series::new( - &index[0], + index[0].clone(), row_to_idx.into_iter().map(|(k, _)| k).collect::>(), ); let s = restore_logical_type(&s, index_s.dtype()); @@ -465,9 +469,11 @@ pub(super) fn compute_row_idx( }, } } else { - let binding = pivot_df.select(index)?; + let binding = pivot_df.select(index.iter().cloned())?; let fields = binding.get_columns(); - let index_struct_series = StructChunked::from_series("placeholder", fields)?.into_series(); + let index_struct_series = + StructChunked::from_series(PlSmallStr::from_static("placeholder"), fields)? + .into_series(); let index_agg = unsafe { index_struct_series.agg_first(groups) }; let index_agg_physical = index_agg.to_physical_repr(); let ca = index_agg_physical.struct_()?; diff --git a/crates/polars-ops/src/frame/pivot/unpivot.rs b/crates/polars-ops/src/frame/pivot/unpivot.rs new file mode 100644 index 000000000000..a9255bdede0e --- /dev/null +++ b/crates/polars-ops/src/frame/pivot/unpivot.rs @@ -0,0 +1,289 @@ +use arrow::array::{MutableArray, MutablePlString}; +use arrow::legacy::kernels::concatenate::concatenate_owned_unchecked; +use polars_core::datatypes::{DataType, PlSmallStr}; +use polars_core::frame::DataFrame; +use polars_core::prelude::{IntoVec, Series, UnpivotArgsIR}; +use polars_core::utils::try_get_supertype; +use polars_error::{polars_err, PolarsResult}; +use polars_utils::aliases::PlHashSet; + +use crate::frame::IntoDf; + +pub trait UnpivotDF: IntoDf { + /// Unpivot a `DataFrame` from wide to long format. + /// + /// # Example + /// + /// # Arguments + /// + /// * `on` - String slice that represent the columns to use as value variables. + /// * `index` - String slice that represent the columns to use as id variables. + /// + /// If `on` is empty all columns that are not in `index` will be used. 
+ /// + /// ```ignore + /// # use polars_core::prelude::*; + /// let df = df!("A" => &["a", "b", "a"], + /// "B" => &[1, 3, 5], + /// "C" => &[10, 11, 12], + /// "D" => &[2, 4, 6] + /// )?; + /// + /// let unpivoted = df.unpivot(&["A", "B"], &["C", "D"])?; + /// println!("{:?}", df); + /// println!("{:?}", unpivoted); + /// # Ok::<(), PolarsError>(()) + /// ``` + /// Outputs: + /// ```text + /// +-----+-----+-----+-----+ + /// | A | B | C | D | + /// | --- | --- | --- | --- | + /// | str | i32 | i32 | i32 | + /// +=====+=====+=====+=====+ + /// | "a" | 1 | 10 | 2 | + /// +-----+-----+-----+-----+ + /// | "b" | 3 | 11 | 4 | + /// +-----+-----+-----+-----+ + /// | "a" | 5 | 12 | 6 | + /// +-----+-----+-----+-----+ + /// + /// +-----+-----+----------+-------+ + /// | A | B | variable | value | + /// | --- | --- | --- | --- | + /// | str | i32 | str | i32 | + /// +=====+=====+==========+=======+ + /// | "a" | 1 | "C" | 10 | + /// +-----+-----+----------+-------+ + /// | "b" | 3 | "C" | 11 | + /// +-----+-----+----------+-------+ + /// | "a" | 5 | "C" | 12 | + /// +-----+-----+----------+-------+ + /// | "a" | 1 | "D" | 2 | + /// +-----+-----+----------+-------+ + /// | "b" | 3 | "D" | 4 | + /// +-----+-----+----------+-------+ + /// | "a" | 5 | "D" | 6 | + /// +-----+-----+----------+-------+ + /// ``` + fn unpivot(&self, on: I, index: J) -> PolarsResult + where + I: IntoVec, + J: IntoVec, + { + let index = index.into_vec(); + let on = on.into_vec(); + self.unpivot2(UnpivotArgsIR { + on, + index, + ..Default::default() + }) + } + + /// Similar to unpivot, but without generics. This may be easier if you want to pass + /// an empty `index` or empty `on`. + fn unpivot2(&self, args: UnpivotArgsIR) -> PolarsResult { + let self_ = self.to_df(); + let index = args.index; + let mut on = args.on; + + let variable_name = args + .variable_name + .unwrap_or_else(|| PlSmallStr::from_static("variable")); + let value_name = args + .value_name + .unwrap_or_else(|| PlSmallStr::from_static("value")); + + if self_.get_columns().is_empty() { + return DataFrame::new(vec![ + Series::new_empty(variable_name, &DataType::String), + Series::new_empty(value_name, &DataType::Null), + ]); + } + + let len = self_.height(); + + // if value vars is empty we take all columns that are not in id_vars. 
+ if on.is_empty() { + // return empty frame if there are no columns available to use as value vars + if index.len() == self_.width() { + let variable_col = Series::new_empty(variable_name, &DataType::String); + let value_col = Series::new_empty(value_name, &DataType::Null); + + let mut out = self_.select(index).unwrap().clear().take_columns(); + out.push(variable_col); + out.push(value_col); + + return Ok(unsafe { DataFrame::new_no_checks(out) }); + } + + let index_set = PlHashSet::from_iter(index.iter().cloned()); + on = self_ + .get_columns() + .iter() + .filter_map(|s| { + if index_set.contains(s.name()) { + None + } else { + Some(s.name().clone()) + } + }) + .collect(); + } + + // values will all be placed in single column, so we must find their supertype + let schema = self_.schema(); + let mut iter = on + .iter() + .map(|v| schema.get(v).ok_or_else(|| polars_err!(col_not_found = v))); + let mut st = iter.next().unwrap()?.clone(); + for dt in iter { + st = try_get_supertype(&st, dt?)?; + } + + // The column name of the variable that is unpivoted + let mut variable_col = MutablePlString::with_capacity(len * on.len() + 1); + // prepare ids + let ids_ = self_.select_with_schema_unchecked(index, &schema)?; + let mut ids = ids_.clone(); + if ids.width() > 0 { + for _ in 0..on.len() - 1 { + ids.vstack_mut_unchecked(&ids_) + } + } + ids.as_single_chunk_par(); + drop(ids_); + + let mut values = Vec::with_capacity(on.len()); + let columns = self_.get_columns(); + + for value_column_name in &on { + variable_col.extend_constant(len, Some(value_column_name.as_str())); + // ensure we go via the schema so we are O(1) + // self.column() is linear + // together with this loop that would make it O^2 over `on` + let (pos, _name, _dtype) = schema.try_get_full(value_column_name)?; + let col = &columns[pos]; + let value_col = col.cast(&st).map_err( + |_| polars_err!(InvalidOperation: "'unpivot' not supported for dtype: {}", col.dtype()), + )?; + values.extend_from_slice(value_col.chunks()) + } + let values_arr = concatenate_owned_unchecked(&values)?; + // SAFETY: + // The give dtype is correct + let values = + unsafe { Series::from_chunks_and_dtype_unchecked(value_name, vec![values_arr], &st) }; + + let variable_col = variable_col.as_box(); + // SAFETY: + // The given dtype is correct + let variables = unsafe { + Series::from_chunks_and_dtype_unchecked( + variable_name, + vec![variable_col], + &DataType::String, + ) + }; + + ids.hstack_mut(&[variables, values])?; + + Ok(ids) + } +} + +impl UnpivotDF for DataFrame {} + +#[cfg(test)] +mod test { + use polars_core::df; + use polars_core::utils::Container; + + use super::*; + + #[test] + fn test_unpivot() -> PolarsResult<()> { + let df = df!("A" => &["a", "b", "a"], + "B" => &[1, 3, 5], + "C" => &[10, 11, 12], + "D" => &[2, 4, 6] + ) + .unwrap(); + + // Specify on and index + let unpivoted = df.unpivot(["C", "D"], ["A", "B"])?; + assert_eq!( + unpivoted.get_column_names(), + &["A", "B", "variable", "value"] + ); + assert_eq!( + Vec::from(unpivoted.column("value")?.i32()?), + &[Some(10), Some(11), Some(12), Some(2), Some(4), Some(6)] + ); + + // Specify custom column names + let args = UnpivotArgsIR { + on: vec!["C".into(), "D".into()], + index: vec!["A".into(), "B".into()], + variable_name: Some("custom_variable".into()), + value_name: Some("custom_value".into()), + }; + let unpivoted = df.unpivot2(args).unwrap(); + assert_eq!( + unpivoted.get_column_names(), + &["A", "B", "custom_variable", "custom_value"] + ); + + // Specify neither on nor index + let args 
= UnpivotArgsIR { + on: vec![], + index: vec![], + ..Default::default() + }; + + let unpivoted = df.unpivot2(args).unwrap(); + assert_eq!(unpivoted.get_column_names(), &["variable", "value"]); + let value = unpivoted.column("value")?; + // String because of supertype + let value = value.str()?; + let value = value.into_no_null_iter().collect::>(); + assert_eq!( + value, + &["a", "b", "a", "1", "3", "5", "10", "11", "12", "2", "4", "6"] + ); + + // Specify index but not on + let args = UnpivotArgsIR { + on: vec![], + index: vec!["A".into()], + ..Default::default() + }; + + let unpivoted = df.unpivot2(args).unwrap(); + assert_eq!(unpivoted.get_column_names(), &["A", "variable", "value"]); + let value = unpivoted.column("value")?; + let value = value.i32()?; + let value = value.into_no_null_iter().collect::>(); + assert_eq!(value, &[1, 3, 5, 10, 11, 12, 2, 4, 6]); + let variable = unpivoted.column("variable")?; + let variable = variable.str()?; + let variable = variable.into_no_null_iter().collect::>(); + assert_eq!(variable, &["B", "B", "B", "C", "C", "C", "D", "D", "D"]); + assert!(unpivoted.column("A").is_ok()); + + // Specify all columns in index + let args = UnpivotArgsIR { + on: vec![], + index: vec!["A".into(), "B".into(), "C".into(), "D".into()], + ..Default::default() + }; + let unpivoted = df.unpivot2(args).unwrap(); + assert_eq!( + unpivoted.get_column_names(), + &["A", "B", "C", "D", "variable", "value"] + ); + assert_eq!(unpivoted.len(), 0); + + Ok(()) + } +} diff --git a/crates/polars-ops/src/lib.rs b/crates/polars-ops/src/lib.rs index 00d10e87c76c..5889f915ef3d 100644 --- a/crates/polars-ops/src/lib.rs +++ b/crates/polars-ops/src/lib.rs @@ -1,7 +1,6 @@ #![cfg_attr(docsrs, feature(doc_auto_cfg))] #![cfg_attr(feature = "nightly", feature(unicode_internals))] #![cfg_attr(feature = "nightly", allow(internal_features))] -extern crate core; pub mod chunked_array; #[cfg(feature = "pivot")] diff --git a/crates/polars-ops/src/prelude.rs b/crates/polars-ops/src/prelude.rs index 1f0717945b49..2353afaefbc8 100644 --- a/crates/polars-ops/src/prelude.rs +++ b/crates/polars-ops/src/prelude.rs @@ -5,5 +5,7 @@ pub use crate::chunked_array::*; #[cfg(feature = "merge_sorted")] pub use crate::frame::_merge_sorted_dfs; pub use crate::frame::join::*; +#[cfg(feature = "pivot")] +pub use crate::frame::pivot::UnpivotDF; pub use crate::frame::{DataFrameJoinOps, DataFrameOps}; pub use crate::series::*; diff --git a/crates/polars-ops/src/series/ops/approx_algo/hyperloglogplus.rs b/crates/polars-ops/src/series/ops/approx_algo/hyperloglogplus.rs index d507a1fcf20c..b341cab65b87 100644 --- a/crates/polars-ops/src/series/ops/approx_algo/hyperloglogplus.rs +++ b/crates/polars-ops/src/series/ops/approx_algo/hyperloglogplus.rs @@ -20,7 +20,7 @@ use std::hash::Hash; use std::marker::PhantomData; -use polars_core::export::ahash::RandomState; +use polars_utils::aliases::PlRandomStateQuality; /// The greater is P, the smaller the error. const HLL_P: usize = 14_usize; @@ -54,7 +54,7 @@ where /// shared across cluster, this SEED will have to be consistent across all /// parties otherwise we might have corruption. So ideally for later this seed /// shall be part of the serialized form (or stay unchanged across versions). 
-const SEED: RandomState = RandomState::with_seeds( +const SEED: PlRandomStateQuality = PlRandomStateQuality::with_seeds( 0x885f6cab121d01a3_u64, 0x71e4379f2976ad8f_u64, 0xbf30173dd28a8816_u64, @@ -81,9 +81,6 @@ where } } - /// choice of hash function: ahash is already an dependency - /// and it fits the requirements of being a 64bit hash with - /// reasonable performance. #[inline] fn hash_value(&self, obj: &T) -> u64 { SEED.hash_one(obj) diff --git a/crates/polars-ops/src/series/ops/approx_unique.rs b/crates/polars-ops/src/series/ops/approx_unique.rs index 31093e06b77a..ab0ea5db8966 100644 --- a/crates/polars-ops/src/series/ops/approx_unique.rs +++ b/crates/polars-ops/src/series/ops/approx_unique.rs @@ -17,7 +17,7 @@ where ca.iter().for_each(|item| hllp.add(&item.to_total_ord())); let c = hllp.count() as IdxSize; - Ok(Series::new(ca.name(), &[c])) + Ok(Series::new(ca.name().clone(), &[c])) } fn dispatcher(s: &Series) -> PolarsResult { @@ -59,7 +59,7 @@ fn dispatcher(s: &Series) -> PolarsResult { /// /// use polars_core::prelude::*; /// -/// let s = Series::new("s", [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3]); +/// let s = Series::new("s".into(), [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3]); /// /// let approx_count = approx_n_unique(&s).unwrap(); /// println!("{}", approx_count); diff --git a/crates/polars-ops/src/series/ops/business.rs b/crates/polars-ops/src/series/ops/business.rs index 17090e9a0b19..eff3e2e8c0ba 100644 --- a/crates/polars-ops/src/series/ops/business.rs +++ b/crates/polars-ops/src/series/ops/business.rs @@ -55,7 +55,7 @@ pub fn business_day_count( ) }) } else { - Int32Chunked::full_null(start_dates.name(), start_dates.len()) + Int32Chunked::full_null(start_dates.name().clone(), start_dates.len()) } }, (1, _) => { @@ -70,7 +70,7 @@ pub fn business_day_count( ) }) } else { - Int32Chunked::full_null(start_dates.name(), end_dates.len()) + Int32Chunked::full_null(start_dates.name().clone(), end_dates.len()) } }, _ => binary_elementwise_values(start_dates, end_dates, |start_date, end_date| { @@ -223,7 +223,7 @@ pub fn add_business_days( )) })? 
} else { - Int32Chunked::full_null(start_dates.name(), start_dates.len()) + Int32Chunked::full_null(start_dates.name().clone(), start_dates.len()) } }, (1, _) => { @@ -241,7 +241,7 @@ pub fn add_business_days( ) }) } else { - Int32Chunked::full_null(start_dates.name(), n.len()) + Int32Chunked::full_null(start_dates.name().clone(), n.len()) } }, _ => try_binary_elementwise(start_dates, n, |opt_start_date, opt_n| { diff --git a/crates/polars-ops/src/series/ops/cum_agg.rs b/crates/polars-ops/src/series/ops/cum_agg.rs index bd498f5088e6..dab529796ddb 100644 --- a/crates/polars-ops/src/series/ops/cum_agg.rs +++ b/crates/polars-ops/src/series/ops/cum_agg.rs @@ -74,7 +74,7 @@ where false => ca.iter().scan(init, det_max).collect_trusted(), true => ca.iter().rev().scan(init, det_max).collect_reversed(), }; - out.with_name(ca.name()) + out.with_name(ca.name().clone()) } fn cum_min_numeric(ca: &ChunkedArray, reverse: bool) -> ChunkedArray @@ -87,7 +87,7 @@ where false => ca.iter().scan(init, det_min).collect_trusted(), true => ca.iter().rev().scan(init, det_min).collect_reversed(), }; - out.with_name(ca.name()) + out.with_name(ca.name().clone()) } fn cum_sum_numeric(ca: &ChunkedArray, reverse: bool) -> ChunkedArray @@ -100,7 +100,7 @@ where false => ca.iter().scan(init, det_sum).collect_trusted(), true => ca.iter().rev().scan(init, det_sum).collect_reversed(), }; - out.with_name(ca.name()) + out.with_name(ca.name().clone()) } fn cum_prod_numeric(ca: &ChunkedArray, reverse: bool) -> ChunkedArray @@ -113,7 +113,7 @@ where false => ca.iter().scan(init, det_prod).collect_trusted(), true => ca.iter().rev().scan(init, det_prod).collect_reversed(), }; - out.with_name(ca.name()) + out.with_name(ca.name().clone()) } /// Get an array with the cumulative product computed at every element. @@ -211,7 +211,7 @@ pub fn cum_max(s: &Series, reverse: bool) -> PolarsResult { pub fn cum_count(s: &Series, reverse: bool) -> PolarsResult { let mut out = if s.null_count() == 0 { // Fast paths for no nulls - cum_count_no_nulls(s.name(), s.len(), reverse) + cum_count_no_nulls(s.name().clone(), s.len(), reverse) } else { let ca = s.is_not_null(); let out: IdxCa = if reverse { @@ -242,7 +242,7 @@ pub fn cum_count(s: &Series, reverse: bool) -> PolarsResult { Ok(out) } -fn cum_count_no_nulls(name: &str, len: usize, reverse: bool) -> Series { +fn cum_count_no_nulls(name: PlSmallStr, len: usize, reverse: bool) -> Series { let start = 1 as IdxSize; let end = len as IdxSize + 1; let ca: NoNull = if reverse { diff --git a/crates/polars-ops/src/series/ops/cut.rs b/crates/polars-ops/src/series/ops/cut.rs index a999fac2a3a0..2deb6dfeb52f 100644 --- a/crates/polars-ops/src/series/ops/cut.rs +++ b/crates/polars-ops/src/series/ops/cut.rs @@ -1,16 +1,17 @@ use polars_core::prelude::*; +use polars_utils::format_pl_smallstr; fn map_cats( s: &Series, - labels: &[String], + labels: &[PlSmallStr], sorted_breaks: &[f64], left_closed: bool, include_breaks: bool, ) -> PolarsResult { - let out_name = "category"; + let out_name = PlSmallStr::from_static("category"); // Create new categorical and pre-register labels for consistent categorical indexes. - let mut bld = CategoricalChunkedBuilder::new(out_name, s.len(), Default::default()); + let mut bld = CategoricalChunkedBuilder::new(out_name.clone(), s.len(), Default::default()); for label in labels { bld.register_value(label); } @@ -33,7 +34,10 @@ fn map_cats( // returned a dataframe. That included a column of the right endpoint of the interval. 
So we // return a struct series instead which can be turned into a dataframe later. let right_ends = [sorted_breaks, &[f64::INFINITY]].concat(); - let mut brk_vals = PrimitiveChunkedBuilder::::new("breakpoint", s.len()); + let mut brk_vals = PrimitiveChunkedBuilder::::new( + PlSmallStr::from_static("breakpoint"), + s.len(), + ); s_iter .map(|opt| { opt.filter(|x| !x.is_nan()).map(|x| { @@ -74,7 +78,7 @@ fn map_cats( } } -pub fn compute_labels(breaks: &[f64], left_closed: bool) -> PolarsResult> { +pub fn compute_labels(breaks: &[f64], left_closed: bool) -> PolarsResult> { let lo = std::iter::once(&f64::NEG_INFINITY).chain(breaks.iter()); let hi = breaks.iter().chain(std::iter::once(&f64::INFINITY)); @@ -82,9 +86,9 @@ pub fn compute_labels(breaks: &[f64], left_closed: bool) -> PolarsResult PolarsResult, - labels: Option>, + labels: Option>, left_closed: bool, include_breaks: bool, ) -> PolarsResult { @@ -120,7 +124,7 @@ pub fn cut( pub fn qcut( s: &Series, probs: Vec, - labels: Option>, + labels: Option>, left_closed: bool, allow_duplicates: bool, include_breaks: bool, @@ -169,9 +173,9 @@ mod test { use super::map_cats; - let s = Series::new("x", &[1, 2, 3, 4, 5]); + let s = Series::new("x".into(), &[1, 2, 3, 4, 5]); - let labels = &["a", "b", "c"].map(str::to_owned); + let labels = &["a", "b", "c"].map(PlSmallStr::from_static); let breaks = &[2.0, 4.0]; let left_closed = false; diff --git a/crates/polars-ops/src/series/ops/duration.rs b/crates/polars-ops/src/series/ops/duration.rs new file mode 100644 index 000000000000..1d5868260e64 --- /dev/null +++ b/crates/polars-ops/src/series/ops/duration.rs @@ -0,0 +1,91 @@ +use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS, SECONDS_IN_DAY}; +use polars_core::datatypes::{AnyValue, DataType, TimeUnit}; +use polars_core::prelude::Series; +use polars_error::PolarsResult; + +pub fn impl_duration(s: &[Series], time_unit: TimeUnit) -> PolarsResult { + if s.iter().any(|s| s.is_empty()) { + return Ok(Series::new_empty( + s[0].name().clone(), + &DataType::Duration(time_unit), + )); + } + + // TODO: Handle overflow for UInt64 + let weeks = s[0].cast(&DataType::Int64).unwrap(); + let days = s[1].cast(&DataType::Int64).unwrap(); + let hours = s[2].cast(&DataType::Int64).unwrap(); + let minutes = s[3].cast(&DataType::Int64).unwrap(); + let seconds = s[4].cast(&DataType::Int64).unwrap(); + let mut milliseconds = s[5].cast(&DataType::Int64).unwrap(); + let mut microseconds = s[6].cast(&DataType::Int64).unwrap(); + let mut nanoseconds = s[7].cast(&DataType::Int64).unwrap(); + + let is_scalar = |s: &Series| s.len() == 1; + let is_zero_scalar = |s: &Series| is_scalar(s) && s.get(0).unwrap() == AnyValue::Int64(0); + + // Process subseconds + let max_len = s.iter().map(|s| s.len()).max().unwrap(); + let mut duration = match time_unit { + TimeUnit::Microseconds => { + if is_scalar(µseconds) { + microseconds = microseconds.new_from_index(0, max_len); + } + if !is_zero_scalar(&nanoseconds) { + microseconds = (microseconds + (nanoseconds.wrapping_trunc_div_scalar(1_000)))?; + } + if !is_zero_scalar(&milliseconds) { + microseconds = (microseconds + (milliseconds * 1_000))?; + } + microseconds + }, + TimeUnit::Nanoseconds => { + if is_scalar(&nanoseconds) { + nanoseconds = nanoseconds.new_from_index(0, max_len); + } + if !is_zero_scalar(µseconds) { + nanoseconds = (nanoseconds + (microseconds * 1_000))?; + } + if !is_zero_scalar(&milliseconds) { + nanoseconds = (nanoseconds + (milliseconds * 1_000_000))?; + } + nanoseconds + }, + TimeUnit::Milliseconds 
=> { + if is_scalar(&milliseconds) { + milliseconds = milliseconds.new_from_index(0, max_len); + } + if !is_zero_scalar(&nanoseconds) { + milliseconds = (milliseconds + (nanoseconds.wrapping_trunc_div_scalar(1_000_000)))?; + } + if !is_zero_scalar(µseconds) { + milliseconds = (milliseconds + (microseconds.wrapping_trunc_div_scalar(1_000)))?; + } + milliseconds + }, + }; + + // Process other duration specifiers + let multiplier = match time_unit { + TimeUnit::Nanoseconds => NANOSECONDS, + TimeUnit::Microseconds => MICROSECONDS, + TimeUnit::Milliseconds => MILLISECONDS, + }; + if !is_zero_scalar(&seconds) { + duration = (duration + seconds * multiplier)?; + } + if !is_zero_scalar(&minutes) { + duration = (duration + minutes * (multiplier * 60))?; + } + if !is_zero_scalar(&hours) { + duration = (duration + hours * (multiplier * 60 * 60))?; + } + if !is_zero_scalar(&days) { + duration = (duration + days * (multiplier * SECONDS_IN_DAY))?; + } + if !is_zero_scalar(&weeks) { + duration = (duration + weeks * (multiplier * SECONDS_IN_DAY * 7))?; + } + + duration.cast(&DataType::Duration(time_unit)) +} diff --git a/crates/polars-ops/src/series/ops/ewm.rs b/crates/polars-ops/src/series/ops/ewm.rs index 22b99a04a892..d6fa9c31a044 100644 --- a/crates/polars-ops/src/series/ops/ewm.rs +++ b/crates/polars-ops/src/series/ops/ewm.rs @@ -21,7 +21,7 @@ pub fn ewm_mean(s: &Series, options: EWMOptions) -> PolarsResult { options.min_periods, options.ignore_nulls, ); - Series::try_from((s.name(), Box::new(result) as ArrayRef)) + Series::try_from((s.name().clone(), Box::new(result) as ArrayRef)) }, DataType::Float64 => { let xs = s.f64().unwrap(); @@ -32,7 +32,7 @@ pub fn ewm_mean(s: &Series, options: EWMOptions) -> PolarsResult { options.min_periods, options.ignore_nulls, ); - Series::try_from((s.name(), Box::new(result) as ArrayRef)) + Series::try_from((s.name().clone(), Box::new(result) as ArrayRef)) }, _ => ewm_mean(&s.cast(&DataType::Float64)?, options), } @@ -51,7 +51,7 @@ pub fn ewm_std(s: &Series, options: EWMOptions) -> PolarsResult { options.min_periods, options.ignore_nulls, ); - Series::try_from((s.name(), Box::new(result) as ArrayRef)) + Series::try_from((s.name().clone(), Box::new(result) as ArrayRef)) }, DataType::Float64 => { let xs = s.f64().unwrap(); @@ -63,7 +63,7 @@ pub fn ewm_std(s: &Series, options: EWMOptions) -> PolarsResult { options.min_periods, options.ignore_nulls, ); - Series::try_from((s.name(), Box::new(result) as ArrayRef)) + Series::try_from((s.name().clone(), Box::new(result) as ArrayRef)) }, _ => ewm_std(&s.cast(&DataType::Float64)?, options), } @@ -82,7 +82,7 @@ pub fn ewm_var(s: &Series, options: EWMOptions) -> PolarsResult { options.min_periods, options.ignore_nulls, ); - Series::try_from((s.name(), Box::new(result) as ArrayRef)) + Series::try_from((s.name().clone(), Box::new(result) as ArrayRef)) }, DataType::Float64 => { let xs = s.f64().unwrap(); @@ -94,7 +94,7 @@ pub fn ewm_var(s: &Series, options: EWMOptions) -> PolarsResult { options.min_periods, options.ignore_nulls, ); - Series::try_from((s.name(), Box::new(result) as ArrayRef)) + Series::try_from((s.name().clone(), Box::new(result) as ArrayRef)) }, _ => ewm_var(&s.cast(&DataType::Float64)?, options), } diff --git a/crates/polars-ops/src/series/ops/ewm_by.rs b/crates/polars-ops/src/series/ops/ewm_by.rs index 9ae0db056ae5..fe79710ab9bf 100644 --- a/crates/polars-ops/src/series/ops/ewm_by.rs +++ b/crates/polars-ops/src/series/ops/ewm_by.rs @@ -135,7 +135,7 @@ where let validity = binary_concatenate_validities(times, 
values); arr = arr.with_validity_typed(validity); } - ChunkedArray::with_chunk(values.name(), arr) + ChunkedArray::with_chunk(values.name().clone(), arr) } /// Fastpath if `times` is known to already be sorted. @@ -184,7 +184,7 @@ where let validity = binary_concatenate_validities(times, values); arr = arr.with_validity_typed(validity); } - ChunkedArray::with_chunk(values.name(), arr) + ChunkedArray::with_chunk(values.name().clone(), arr) } fn adjust_half_life_to_time_unit(half_life: i64, time_unit: &TimeUnit) -> i64 { diff --git a/crates/polars-ops/src/series/ops/fused.rs b/crates/polars-ops/src/series/ops/fused.rs index 86c8b5656fe0..16b06f76c479 100644 --- a/crates/polars-ops/src/series/ops/fused.rs +++ b/crates/polars-ops/src/series/ops/fused.rs @@ -38,7 +38,7 @@ fn fma_ca( .zip(b.downcast_iter()) .zip(c.downcast_iter()) .map(|((a, b), c)| fma_arr(a, b, c)); - ChunkedArray::from_chunk_iter(a.name(), chunks) + ChunkedArray::from_chunk_iter(a.name().clone(), chunks) } pub fn fma_series(a: &Series, b: &Series, c: &Series) -> Series { @@ -89,7 +89,7 @@ fn fsm_ca( .zip(b.downcast_iter()) .zip(c.downcast_iter()) .map(|((a, b), c)| fsm_arr(a, b, c)); - ChunkedArray::from_chunk_iter(a.name(), chunks) + ChunkedArray::from_chunk_iter(a.name().clone(), chunks) } pub fn fsm_series(a: &Series, b: &Series, c: &Series) -> Series { @@ -139,7 +139,7 @@ fn fms_ca( .zip(b.downcast_iter()) .zip(c.downcast_iter()) .map(|((a, b), c)| fms_arr(a, b, c)); - ChunkedArray::from_chunk_iter(a.name(), chunks) + ChunkedArray::from_chunk_iter(a.name().clone(), chunks) } pub fn fms_series(a: &Series, b: &Series, c: &Series) -> Series { diff --git a/crates/polars-ops/src/series/ops/horizontal.rs b/crates/polars-ops/src/series/ops/horizontal.rs index c8e3488aab93..4412e2aa21d1 100644 --- a/crates/polars-ops/src/series/ops/horizontal.rs +++ b/crates/polars-ops/src/series/ops/horizontal.rs @@ -4,25 +4,25 @@ use polars_core::prelude::*; pub fn max_horizontal(s: &[Series]) -> PolarsResult> { let df = unsafe { DataFrame::new_no_checks(Vec::from(s)) }; df.max_horizontal() - .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name()))) + .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name().clone()))) } pub fn min_horizontal(s: &[Series]) -> PolarsResult> { let df = unsafe { DataFrame::new_no_checks(Vec::from(s)) }; df.min_horizontal() - .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name()))) + .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name().clone()))) } pub fn sum_horizontal(s: &[Series]) -> PolarsResult> { let df = unsafe { DataFrame::new_no_checks(Vec::from(s)) }; df.sum_horizontal(NullStrategy::Ignore) - .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name()))) + .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name().clone()))) } pub fn mean_horizontal(s: &[Series]) -> PolarsResult> { let df = unsafe { DataFrame::new_no_checks(Vec::from(s)) }; df.mean_horizontal(NullStrategy::Ignore) - .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name()))) + .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name().clone()))) } pub fn coalesce_series(s: &[Series]) -> PolarsResult { diff --git a/crates/polars-ops/src/series/ops/int_range.rs b/crates/polars-ops/src/series/ops/int_range.rs index 4c68b2280635..5e5a3d419acb 100644 --- a/crates/polars-ops/src/series/ops/int_range.rs +++ b/crates/polars-ops/src/series/ops/int_range.rs @@ -5,7 +5,7 @@ pub fn new_int_range( start: T::Native, end: T::Native, step: i64, - name: &str, + name: PlSmallStr, ) -> PolarsResult where T: PolarsIntegerType, diff --git 
a/crates/polars-ops/src/series/ops/interpolation/interpolate.rs b/crates/polars-ops/src/series/ops/interpolation/interpolate.rs index 11af19651fe0..36d9dc12e556 100644 --- a/crates/polars-ops/src/series/ops/interpolation/interpolate.rs +++ b/crates/polars-ops/src/series/ops/interpolation/interpolate.rs @@ -103,9 +103,9 @@ where out.into(), Some(validity.into()), ); - ChunkedArray::with_chunk(chunked_arr.name(), array) + ChunkedArray::with_chunk(chunked_arr.name().clone(), array) } else { - ChunkedArray::from_vec(chunked_arr.name(), out) + ChunkedArray::from_vec(chunked_arr.name().clone(), out) } } @@ -211,7 +211,7 @@ mod test { #[test] fn test_interpolate() { - let ca = UInt32Chunked::new("", &[Some(1), None, None, Some(4), Some(5)]); + let ca = UInt32Chunked::new("".into(), &[Some(1), None, None, Some(4), Some(5)]); let out = interpolate(&ca.into_series(), InterpolationMethod::Linear); let out = out.f64().unwrap(); assert_eq!( @@ -219,7 +219,7 @@ mod test { &[Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)] ); - let ca = UInt32Chunked::new("", &[None, Some(1), None, None, Some(4), Some(5)]); + let ca = UInt32Chunked::new("".into(), &[None, Some(1), None, None, Some(4), Some(5)]); let out = interpolate(&ca.into_series(), InterpolationMethod::Linear); let out = out.f64().unwrap(); assert_eq!( @@ -227,7 +227,10 @@ mod test { &[None, Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)] ); - let ca = UInt32Chunked::new("", &[None, Some(1), None, None, Some(4), Some(5), None]); + let ca = UInt32Chunked::new( + "".into(), + &[None, Some(1), None, None, Some(4), Some(5), None], + ); let out = interpolate(&ca.into_series(), InterpolationMethod::Linear); let out = out.f64().unwrap(); assert_eq!( @@ -242,7 +245,10 @@ mod test { None ] ); - let ca = UInt32Chunked::new("", &[None, Some(1), None, None, Some(4), Some(5), None]); + let ca = UInt32Chunked::new( + "".into(), + &[None, Some(1), None, None, Some(4), Some(5), None], + ); let out = interpolate(&ca.into_series(), InterpolationMethod::Nearest); let out = out.u32().unwrap(); assert_eq!( @@ -253,7 +259,7 @@ mod test { #[test] fn test_interpolate_decreasing_unsigned() { - let ca = UInt32Chunked::new("", &[Some(4), None, None, Some(1)]); + let ca = UInt32Chunked::new("".into(), &[Some(4), None, None, Some(1)]); let out = interpolate(&ca.into_series(), InterpolationMethod::Linear); let out = out.f64().unwrap(); assert_eq!( @@ -265,7 +271,7 @@ mod test { #[test] fn test_interpolate2() { let ca = Float32Chunked::new( - "", + "".into(), &[ Some(4653f32), None, diff --git a/crates/polars-ops/src/series/ops/interpolation/interpolate_by.rs b/crates/polars-ops/src/series/ops/interpolation/interpolate_by.rs index 674cbab514e9..06a8378055da 100644 --- a/crates/polars-ops/src/series/ops/interpolation/interpolate_by.rs +++ b/crates/polars-ops/src/series/ops/interpolation/interpolate_by.rs @@ -87,7 +87,7 @@ fn interpolate_impl_by_sorted( ) -> PolarsResult> where T: PolarsNumericType, - F: PolarsIntegerType, + F: PolarsNumericType, I: Fn(T::Native, T::Native, &[F::Native], &mut Vec), { // This implementation differs from pandas as that boundary None's are not removed. 
@@ -155,9 +155,9 @@ where out.into(), Some(validity.into()), ); - Ok(ChunkedArray::with_chunk(chunked_arr.name(), array)) + Ok(ChunkedArray::with_chunk(chunked_arr.name().clone(), array)) } else { - Ok(ChunkedArray::from_vec(chunked_arr.name(), out)) + Ok(ChunkedArray::from_vec(chunked_arr.name().clone(), out)) } } @@ -169,7 +169,7 @@ fn interpolate_impl_by( ) -> PolarsResult> where T: PolarsNumericType, - F: PolarsIntegerType, + F: PolarsNumericType, I: Fn(T::Native, T::Native, &[F::Native], &mut [T::Native], &[IdxSize]), { // This implementation differs from pandas as that boundary None's are not removed. @@ -257,9 +257,9 @@ where out.into(), Some(validity.into()), ); - Ok(ChunkedArray::with_chunk(ca_sorted.name(), array)) + Ok(ChunkedArray::with_chunk(ca_sorted.name().clone(), array)) } else { - Ok(ChunkedArray::from_vec(ca_sorted.name(), out)) + Ok(ChunkedArray::from_vec(ca_sorted.name().clone(), out)) } } @@ -273,7 +273,7 @@ pub fn interpolate_by(s: &Series, by: &Series, by_is_sorted: bool) -> PolarsResu ) -> PolarsResult where T: PolarsNumericType, - F: PolarsIntegerType, + F: PolarsNumericType, ChunkedArray: IntoSeries, { if is_sorted { @@ -290,6 +290,18 @@ pub fn interpolate_by(s: &Series, by: &Series, by_is_sorted: bool) -> PolarsResu } match (s.dtype(), by.dtype()) { + (DataType::Float64, DataType::Float64) => { + func(s.f64().unwrap(), by.f64().unwrap(), by_is_sorted) + }, + (DataType::Float64, DataType::Float32) => { + func(s.f64().unwrap(), by.f32().unwrap(), by_is_sorted) + }, + (DataType::Float32, DataType::Float64) => { + func(s.f32().unwrap(), by.f64().unwrap(), by_is_sorted) + }, + (DataType::Float32, DataType::Float32) => { + func(s.f32().unwrap(), by.f32().unwrap(), by_is_sorted) + }, (DataType::Float64, DataType::Int64) => { func(s.f64().unwrap(), by.i64().unwrap(), by_is_sorted) }, @@ -326,7 +338,7 @@ pub fn interpolate_by(s: &Series, by: &Series, by_is_sorted: bool) -> PolarsResu _ => { polars_bail!(InvalidOperation: "expected series to be Float64, Float32, \ Int64, Int32, UInt64, UInt32, and `by` to be Date, Datetime, Int64, Int32, \ - UInt64, or UInt32") + UInt64, UInt32, Float32 or Float64") }, } } diff --git a/crates/polars-ops/src/series/ops/is_first_distinct.rs b/crates/polars-ops/src/series/ops/is_first_distinct.rs index d3440340d9a7..4fdb10e162c3 100644 --- a/crates/polars-ops/src/series/ops/is_first_distinct.rs +++ b/crates/polars-ops/src/series/ops/is_first_distinct.rs @@ -20,7 +20,7 @@ where .collect_trusted() }); - BooleanChunked::from_chunk_iter(ca.name(), chunks) + BooleanChunked::from_chunk_iter(ca.name().clone(), chunks) } fn is_first_distinct_bin(ca: &BinaryChunked) -> BooleanChunked { @@ -31,7 +31,7 @@ fn is_first_distinct_bin(ca: &BinaryChunked) -> BooleanChunked { .collect_trusted() }); - BooleanChunked::from_chunk_iter(ca.name(), chunks) + BooleanChunked::from_chunk_iter(ca.name().clone(), chunks) } fn is_first_distinct_boolean(ca: &BooleanChunked) -> BooleanChunked { @@ -69,7 +69,7 @@ fn is_first_distinct_boolean(ca: &BooleanChunked) -> BooleanChunked { } } let arr = BooleanArray::new(ArrowDataType::Boolean, out.into(), None); - BooleanChunked::with_chunk(ca.name(), arr) + BooleanChunked::with_chunk(ca.name().clone(), arr) } #[cfg(feature = "dtype-struct")] @@ -85,7 +85,7 @@ fn is_first_distinct_struct(s: &Series) -> PolarsResult { } let arr = BooleanArray::new(ArrowDataType::Boolean, out.into(), None); - Ok(BooleanChunked::with_chunk(s.name(), arr)) + Ok(BooleanChunked::with_chunk(s.name().clone(), arr)) } fn is_first_distinct_list(ca: 
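The interpolate_by hunk above relaxes the `by` bound from PolarsIntegerType to PolarsNumericType, so Float32 and Float64 `by` columns now dispatch like the integer ones. The fill itself is ordinary linear interpolation at the `by` positions; a tiny sketch, assuming f64 scalars and a hypothetical interp_at helper:

// Fill a missing value at by-position x from its known neighbours (x0, y0) and (x1, y1).
fn interp_at(x: f64, x0: f64, y0: f64, x1: f64, y1: f64) -> f64 {
    y0 + (y1 - y0) * (x - x0) / (x1 - x0)
}

fn main() {
    // Known values at by-positions 0.0 and 3.0; fill the gap at 1.0.
    assert_eq!(interp_at(1.0, 0.0, 10.0, 3.0, 40.0), 20.0);
}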
&ListChunked) -> PolarsResult { @@ -100,15 +100,15 @@ fn is_first_distinct_list(ca: &ListChunked) -> PolarsResult { } let arr = BooleanArray::new(ArrowDataType::Boolean, out.into(), None); - Ok(BooleanChunked::with_chunk(ca.name(), arr)) + Ok(BooleanChunked::with_chunk(ca.name().clone(), arr)) } pub fn is_first_distinct(s: &Series) -> PolarsResult { // fast path. if s.len() == 0 { - return Ok(BooleanChunked::full_null(s.name(), 0)); + return Ok(BooleanChunked::full_null(s.name().clone(), 0)); } else if s.len() == 1 { - return Ok(BooleanChunked::new(s.name(), &[true])); + return Ok(BooleanChunked::new(s.name().clone(), &[true])); } let s = s.to_physical_repr(); diff --git a/crates/polars-ops/src/series/ops/is_in.rs b/crates/polars-ops/src/series/ops/is_in.rs index 7a4bdfa3f495..0e3f847307ed 100644 --- a/crates/polars-ops/src/series/ops/is_in.rs +++ b/crates/polars-ops/src/series/ops/is_in.rs @@ -5,7 +5,7 @@ use polars_core::prelude::*; use polars_core::utils::{try_get_supertype, CustomIterTools}; use polars_core::with_match_physical_numeric_polars_type; #[cfg(feature = "dtype-categorical")] -use polars_utils::iter::EnumerateIdxTrait; +use polars_utils::itertools::Itertools; use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; fn is_in_helper_ca<'a, T>( @@ -25,7 +25,10 @@ where } }) }); - Ok(unary_elementwise_values(ca, |val| set.contains(&val.to_total_ord())).with_name(ca.name())) + Ok( + unary_elementwise_values(ca, |val| set.contains(&val.to_total_ord())) + .with_name(ca.name().clone()), + ) } fn is_in_helper<'a, T>(ca: &'a ChunkedArray, other: &Series) -> PolarsResult @@ -70,7 +73,7 @@ where .collect_trusted() } }; - ca.rename(ca_in.name()); + ca.rename(ca_in.name().clone()); Ok(ca) } @@ -105,7 +108,7 @@ where }) .collect_trusted() }; - ca.rename(ca_in.name()); + ca.rename(ca_in.name().clone()); Ok(ca) } @@ -198,7 +201,7 @@ fn is_in_string_list_categorical( .collect() } }; - ca.rename(ca_in.name()); + ca.rename(ca_in.name().clone()); Ok(ca) } @@ -267,7 +270,7 @@ fn is_in_binary_list(ca_in: &BinaryChunked, other: &Series) -> PolarsResult PolarsResult PolarsResult PolarsResult PolarsResult polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()), } @@ -449,7 +452,7 @@ fn is_in_struct_list(ca_in: &StructChunked, other: &Series) -> PolarsResult PolarsResult PolarsResult = ca_in - .struct_fields() - .iter() - .map(|f| f.data_type()) - .collect(); - let other_dtypes: Vec<_> = other - .struct_fields() - .iter() - .map(|f| f.data_type()) - .collect(); + let ca_in_dtypes: Vec<_> = ca_in.struct_fields().iter().map(|f| f.dtype()).collect(); + let other_dtypes: Vec<_> = other.struct_fields().iter().map(|f| f.dtype()).collect(); if ca_in_dtypes != other_dtypes { - let ca_in_names = ca_in.struct_fields().iter().map(|f| f.name()); - let other_names = other.struct_fields().iter().map(|f| f.name()); + let ca_in_names = ca_in.struct_fields().iter().map(|f| f.name().clone()); + let other_names = other.struct_fields().iter().map(|f| f.name().clone()); let supertypes = ca_in_dtypes .iter() .zip(other_dtypes.iter()) @@ -570,7 +565,10 @@ fn is_in_string_categorical( ) -> PolarsResult { // In case of fast unique, we can directly use the categories. 
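The is_in helper above hashes the right-hand values once and then maps every left-hand value to a set-membership check. A simplified sketch of that idea, assuming plain i64 values and std's HashSet in place of PlHashSet and the total-ordering wrappers:

use std::collections::HashSet;

// Build the lookup set once, then test each value against it.
fn is_in(values: &[i64], other: &[i64]) -> Vec<bool> {
    let set: HashSet<i64> = other.iter().copied().collect();
    values.iter().map(|v| set.contains(v)).collect()
}

fn main() {
    assert_eq!(is_in(&[1, 2, 3], &[2, 4]), vec![false, true, false]);
}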
Otherwise we need to // first get the unique physicals - let categories = StringChunked::with_chunk("", other.get_rev_map().get_categories().clone()); + let categories = StringChunked::with_chunk( + PlSmallStr::EMPTY, + other.get_rev_map().get_categories().clone(), + ); let other = if other._can_fast_unique() { categories } else { @@ -624,7 +622,7 @@ fn is_in_cat(ca_in: &CategoricalChunked, other: &Series) -> PolarsResult PolarsResult PolarsResult { // fast path. if s.len() == 0 { - return Ok(BooleanChunked::full_null(s.name(), 0)); + return Ok(BooleanChunked::full_null(s.name().clone(), 0)); } else if s.len() == 1 { - return Ok(BooleanChunked::new(s.name(), &[true])); + return Ok(BooleanChunked::new(s.name().clone(), &[true])); } let s = s.to_physical_repr(); @@ -107,7 +107,7 @@ fn is_last_distinct_boolean(ca: &BooleanChunked) -> BooleanChunked { } let arr = BooleanArray::new(ArrowDataType::Boolean, out.into(), None); - BooleanChunked::with_chunk(ca.name(), arr) + BooleanChunked::with_chunk(ca.name().clone(), arr) } fn is_last_distinct_bin(ca: &BinaryChunked) -> BooleanChunked { @@ -120,7 +120,7 @@ fn is_last_distinct_bin(ca: &BinaryChunked) -> BooleanChunked { .map(|opt_v| unique.insert(opt_v)) .collect_reversed::>() .into_inner(); - new_ca.rename(ca.name()); + new_ca.rename(ca.name().clone()); new_ca } @@ -139,7 +139,7 @@ where .map(|opt_v| unique.insert(opt_v.to_total_ord())) .collect_reversed::>() .into_inner(); - new_ca.rename(ca.name()); + new_ca.rename(ca.name().clone()); new_ca } @@ -157,7 +157,7 @@ fn is_last_distinct_struct(s: &Series) -> PolarsResult { } let arr = BooleanArray::new(ArrowDataType::Boolean, out.into(), None); - Ok(BooleanChunked::with_chunk(s.name(), arr)) + Ok(BooleanChunked::with_chunk(s.name().clone(), arr)) } fn is_last_distinct_list(ca: &ListChunked) -> PolarsResult { @@ -173,5 +173,5 @@ fn is_last_distinct_list(ca: &ListChunked) -> PolarsResult { } let arr = BooleanArray::new(ArrowDataType::Boolean, out.into(), None); - Ok(BooleanChunked::with_chunk(ca.name(), arr)) + Ok(BooleanChunked::with_chunk(ca.name().clone(), arr)) } diff --git a/crates/polars-ops/src/series/ops/is_unique.rs b/crates/polars-ops/src/series/ops/is_unique.rs index 265e8736b35e..2f1d3de652ba 100644 --- a/crates/polars-ops/src/series/ops/is_unique.rs +++ b/crates/polars-ops/src/series/ops/is_unique.rs @@ -36,7 +36,7 @@ where unsafe { values.set_unchecked(idx as usize, setter) } } let arr = BooleanArray::from_data_default(values.into(), None); - BooleanChunked::with_chunk(ca.name(), arr) + BooleanChunked::with_chunk(ca.name().clone(), arr) } fn dispatcher(s: &Series, invert: bool) -> PolarsResult { @@ -75,9 +75,9 @@ fn dispatcher(s: &Series, invert: bool) -> PolarsResult { }; }, Null => match s.len() { - 0 => BooleanChunked::new(s.name(), [] as [bool; 0]), - 1 => BooleanChunked::new(s.name(), [!invert]), - len => BooleanChunked::full(s.name(), invert, len), + 0 => BooleanChunked::new(s.name().clone(), [] as [bool; 0]), + 1 => BooleanChunked::new(s.name().clone(), [!invert]), + len => BooleanChunked::full(s.name().clone(), invert, len), }, dt if dt.is_numeric() => { with_match_physical_integer_polars_type!(s.dtype(), |$T| { diff --git a/crates/polars-ops/src/series/ops/log.rs b/crates/polars-ops/src/series/ops/log.rs index 7b914071de80..e0c870321a47 100644 --- a/crates/polars-ops/src/series/ops/log.rs +++ b/crates/polars-ops/src/series/ops/log.rs @@ -92,7 +92,7 @@ pub trait LogSeries: SeriesSealed { let pk = s.as_ref(); let pk = if normalize { - let sum = 
pk.sum_reduce().unwrap().into_series(""); + let sum = pk.sum_reduce().unwrap().into_series(PlSmallStr::EMPTY); if sum.get(0).unwrap().extract::().unwrap() != 1.0 { (pk / &sum)? diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs index 75c40c6d500d..ed4a446f3cca 100644 --- a/crates/polars-ops/src/series/ops/mod.rs +++ b/crates/polars-ops/src/series/ops/mod.rs @@ -136,6 +136,11 @@ pub use to_dummies::*; pub use unique::*; pub use various::*; mod not; + +#[cfg(feature = "dtype-duration")] +pub(crate) mod duration; +#[cfg(feature = "dtype-duration")] +pub use duration::*; pub use not::*; pub trait SeriesSealed { diff --git a/crates/polars-ops/src/series/ops/moment.rs b/crates/polars-ops/src/series/ops/moment.rs index be20c8ae981e..3a51d39213de 100644 --- a/crates/polars-ops/src/series/ops/moment.rs +++ b/crates/polars-ops/src/series/ops/moment.rs @@ -129,7 +129,7 @@ mod test { #[test] fn test_moment_compute() -> PolarsResult<()> { - let s = Series::new("", &[1, 2, 3, 4, 5, 23]); + let s = Series::new(PlSmallStr::EMPTY, &[1, 2, 3, 4, 5, 23]); assert_eq!(moment(&s, 0)?, Some(1.0)); assert_eq!(moment(&s, 1)?, Some(0.0)); @@ -141,8 +141,11 @@ mod test { #[test] fn test_skew() -> PolarsResult<()> { - let s = Series::new("", &[1, 2, 3, 4, 5, 23]); - let s2 = Series::new("", &[Some(1), Some(2), Some(3), None, Some(1)]); + let s = Series::new(PlSmallStr::EMPTY, &[1, 2, 3, 4, 5, 23]); + let s2 = Series::new( + PlSmallStr::EMPTY, + &[Some(1), Some(2), Some(3), None, Some(1)], + ); assert!((s.skew(false)?.unwrap() - 2.2905330058490514).abs() < 0.0001); assert!((s.skew(true)?.unwrap() - 1.6727687946848508).abs() < 0.0001); @@ -155,7 +158,7 @@ mod test { #[test] fn test_kurtosis() -> PolarsResult<()> { - let s = Series::new("", &[1, 2, 3, 4, 5, 23]); + let s = Series::new(PlSmallStr::EMPTY, &[1, 2, 3, 4, 5, 23]); assert!((s.kurtosis(true, true)?.unwrap() - 0.9945668771797536).abs() < 0.0001); assert!((s.kurtosis(true, false)?.unwrap() - 5.400820058440946).abs() < 0.0001); @@ -163,7 +166,7 @@ mod test { assert!((s.kurtosis(false, false)?.unwrap() - 8.400820058440946).abs() < 0.0001); let s2 = Series::new( - "", + PlSmallStr::EMPTY, &[Some(1), Some(2), Some(3), None, Some(1), Some(2), Some(3)], ); assert!((s2.kurtosis(true, true)?.unwrap() - (-1.5)).abs() < 0.0001); diff --git a/crates/polars-ops/src/series/ops/pct_change.rs b/crates/polars-ops/src/series/ops/pct_change.rs index 56c7af142e9b..9cb45dac1d6f 100644 --- a/crates/polars-ops/src/series/ops/pct_change.rs +++ b/crates/polars-ops/src/series/ops/pct_change.rs @@ -20,6 +20,6 @@ pub fn pct_change(s: &Series, n: &Series) -> PolarsResult { if let Some(n) = n_s.i64()?.get(0) { diff(&fill_null_s, n, NullBehavior::Ignore)?.divide(&fill_null_s.shift(n)) } else { - Ok(Series::full_null(s.name(), s.len(), s.dtype())) + Ok(Series::full_null(s.name().clone(), s.len(), s.dtype())) } } diff --git a/crates/polars-ops/src/series/ops/rank.rs b/crates/polars-ops/src/series/ops/rank.rs index dd2fe3936945..4021443a3534 100644 --- a/crates/polars-ops/src/series/ops/rank.rs +++ b/crates/polars-ops/src/series/ops/rank.rs @@ -65,17 +65,26 @@ unsafe fn rank_impl(idxs: &IdxCa, neq: &BooleanArray, fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option) -> Series { let len = s.len(); let null_count = s.null_count(); + + if null_count == len { + let dt = match method { + Average => DataType::Float64, + _ => IDX_DTYPE, + }; + return Series::full_null(s.name().clone(), s.len(), &dt); + } + match len { 1 => { return match 
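The entropy change above normalises the probability vector by its sum when it does not already sum to one, then applies the entropy formula. A standalone sketch, assuming natural-log entropy over an f64 slice (the real method also supports other log bases):

// -sum(p * ln p) after optionally rescaling so the probabilities sum to 1.
fn entropy(pk: &[f64], normalize: bool) -> f64 {
    let sum: f64 = pk.iter().sum();
    let scale = if normalize && sum != 1.0 { sum } else { 1.0 };
    -pk.iter()
        .map(|&p| p / scale)
        .filter(|&p| p > 0.0)
        .map(|p| p * p.ln())
        .sum::<f64>()
}

fn main() {
    // A uniform distribution over four outcomes has entropy ln(4).
    let e = entropy(&[1.0, 1.0, 1.0, 1.0], true);
    assert!((e - 4f64.ln()).abs() < 1e-12);
}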
method { - Average => Series::new(s.name(), &[1.0f64]), - _ => Series::new(s.name(), &[1 as IdxSize]), + Average => Series::new(s.name().clone(), &[1.0f64]), + _ => Series::new(s.name().clone(), &[1 as IdxSize]), }; }, 0 => { return match method { - Average => Float64Chunked::from_slice(s.name(), &[]).into_series(), - _ => IdxCa::from_slice(s.name(), &[]).into_series(), + Average => Float64Chunked::from_slice(s.name().clone(), &[]).into_series(), + _ => IdxCa::from_slice(s.name().clone(), &[]).into_series(), }; }, _ => {}, @@ -83,8 +92,8 @@ fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option) -> if null_count == len { return match method { - Average => Float64Chunked::full_null(s.name(), len).into_series(), - _ => IdxCa::full_null(s.name(), len).into_series(), + Average => Float64Chunked::full_null(s.name().clone(), len).into_series(), + _ => IdxCa::full_null(s.name().clone(), len).into_series(), }; } @@ -109,7 +118,7 @@ fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option) -> rank += 1; } } - IdxCa::from_vec_validity(s.name(), out, validity).into_series() + IdxCa::from_vec_validity(s.name().clone(), out, validity).into_series() } else { let sorted_values = unsafe { s.take_unchecked(&sort_idx_ca) }; let not_consecutive_same = sorted_values @@ -132,7 +141,7 @@ fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option) -> rank += 1; } }); - IdxCa::from_vec_validity(s.name(), out, validity).into_series() + IdxCa::from_vec_validity(s.name().clone(), out, validity).into_series() }, Average => unsafe { let mut out = vec![0.0; s.len()]; @@ -145,7 +154,7 @@ fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option) -> *out.get_unchecked_mut(*i as usize) = avg; } }); - Float64Chunked::from_vec_validity(s.name(), out, validity).into_series() + Float64Chunked::from_vec_validity(s.name().clone(), out, validity).into_series() }, Min => unsafe { let mut out = vec![0 as IdxSize; s.len()]; @@ -155,7 +164,7 @@ fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option) -> } rank += ties.len() as IdxSize; }); - IdxCa::from_vec_validity(s.name(), out, validity).into_series() + IdxCa::from_vec_validity(s.name().clone(), out, validity).into_series() }, Max => unsafe { let mut out = vec![0 as IdxSize; s.len()]; @@ -165,7 +174,7 @@ fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option) -> *out.get_unchecked_mut(*i as usize) = rank - 1; } }); - IdxCa::from_vec_validity(s.name(), out, validity).into_series() + IdxCa::from_vec_validity(s.name().clone(), out, validity).into_series() }, Dense => unsafe { let mut out = vec![0 as IdxSize; s.len()]; @@ -175,7 +184,7 @@ fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option) -> } rank += 1; }); - IdxCa::from_vec_validity(s.name(), out, validity).into_series() + IdxCa::from_vec_validity(s.name().clone(), out, validity).into_series() }, Ordinal => unreachable!(), } @@ -196,7 +205,7 @@ mod test { #[test] fn test_rank() -> PolarsResult<()> { - let s = Series::new("a", &[1, 2, 3, 2, 2, 3, 0]); + let s = Series::new("a".into(), &[1, 2, 3, 2, 2, 3, 0]); let out = rank(&s, RankMethod::Ordinal, false, None) .idx()? 
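The rank hunk above adds an early return for all-null input whose dtype depends on the method (Float64 for Average, the index dtype otherwise); the non-degenerate path scatters positions back through the sort permutation. A simplified sketch of ordinal ranking that breaks ties by original position (the real implementation can also randomise ties via a seed):

// Ordinal rank: argsort the values, then write 1-based positions through that permutation.
fn ordinal_rank(values: &[i64]) -> Vec<u32> {
    let mut idx: Vec<usize> = (0..values.len()).collect();
    idx.sort_by_key(|&i| values[i]); // stable sort keeps ties in original order
    let mut out = vec![0u32; values.len()];
    for (pos, &i) in idx.iter().enumerate() {
        out[i] = pos as u32 + 1;
    }
    out
}

fn main() {
    assert_eq!(ordinal_rank(&[1, 2, 3, 2, 2, 3, 0]), vec![2, 3, 6, 4, 5, 7, 1]);
}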
@@ -244,7 +253,7 @@ mod test { assert_eq!(out, &[2.0f64, 4.0, 6.5, 4.0, 4.0, 6.5, 1.0]); let s = Series::new( - "a", + "a".into(), &[Some(1), Some(2), Some(3), Some(2), None, None, Some(0)], ); @@ -266,7 +275,7 @@ mod test { ] ); let s = Series::new( - "a", + "a".into(), &[ Some(5), Some(6), @@ -301,7 +310,7 @@ mod test { #[test] fn test_rank_all_null() -> PolarsResult<()> { - let s = UInt32Chunked::new("", &[None, None, None]).into_series(); + let s = UInt32Chunked::new("".into(), &[None, None, None]).into_series(); let out = rank(&s, RankMethod::Average, false, None) .f64()? .into_iter() @@ -317,7 +326,7 @@ mod test { #[test] fn test_rank_empty() { - let s = UInt32Chunked::from_slice("", &[]).into_series(); + let s = UInt32Chunked::from_slice("".into(), &[]).into_series(); let out = rank(&s, RankMethod::Average, false, None); assert_eq!(out.dtype(), &DataType::Float64); let out = rank(&s, RankMethod::Max, false, None); @@ -326,7 +335,7 @@ mod test { #[test] fn test_rank_reverse() -> PolarsResult<()> { - let s = Series::new("", &[None, Some(1), Some(1), Some(5), None]); + let s = Series::new("".into(), &[None, Some(1), Some(1), Some(5), None]); let out = rank(&s, RankMethod::Dense, true, None) .idx()? .into_iter() diff --git a/crates/polars-ops/src/series/ops/replace.rs b/crates/polars-ops/src/series/ops/replace.rs index a331078318ea..ff9f8f18760d 100644 --- a/crates/polars-ops/src/series/ops/replace.rs +++ b/crates/polars-ops/src/series/ops/replace.rs @@ -138,7 +138,7 @@ fn replace_by_single_strict(s: &Series, old: &Series, new: &Series) -> PolarsRes // Transfer validity from `mask` to `out`. if mask.null_count() > 0 { - out = out.zip_with(&mask, &Series::new_null("", s.len()))? + out = out.zip_with(&mask, &Series::new_null(PlSmallStr::EMPTY, s.len()))? } Ok(out) } @@ -169,7 +169,7 @@ fn replace_by_multiple( let joined = df.join( &replacer, - [s.name()], + [s.name().as_str()], ["__POLARS_REPLACE_OLD"], JoinArgs { how: JoinType::Left, @@ -207,7 +207,7 @@ fn replace_by_multiple_strict(s: &Series, old: Series, new: Series) -> PolarsRes let joined = df.join( &replacer, - [s.name()], + [s.name().as_str()], ["__POLARS_REPLACE_OLD"], JoinArgs { how: JoinType::Left, @@ -231,11 +231,12 @@ fn replace_by_multiple_strict(s: &Series, old: Series, new: Series) -> PolarsRes // Build replacer dataframe. 
fn create_replacer(mut old: Series, mut new: Series, add_mask: bool) -> PolarsResult { - old.rename("__POLARS_REPLACE_OLD"); - new.rename("__POLARS_REPLACE_NEW"); + old.rename(PlSmallStr::from_static("__POLARS_REPLACE_OLD")); + new.rename(PlSmallStr::from_static("__POLARS_REPLACE_NEW")); let cols = if add_mask { - let mask = Series::new("__POLARS_REPLACE_MASK", &[true]).new_from_index(0, new.len()); + let mask = Series::new(PlSmallStr::from_static("__POLARS_REPLACE_MASK"), &[true]) + .new_from_index(0, new.len()); vec![old, new, mask] } else { vec![old, new] diff --git a/crates/polars-ops/src/series/ops/rle.rs b/crates/polars-ops/src/series/ops/rle.rs index 671f5c561f39..8659512673f1 100644 --- a/crates/polars-ops/src/series/ops/rle.rs +++ b/crates/polars-ops/src/series/ops/rle.rs @@ -9,7 +9,7 @@ pub fn rle(s: &Series) -> PolarsResult { let mut lengths = Vec::::with_capacity(n_runs as usize); lengths.push(1); - let mut vals = Series::new_empty("value", s.dtype()); + let mut vals = Series::new_empty(PlSmallStr::from_static("value"), s.dtype()); let vals = vals.extend(&s.head(Some(1)))?.extend(&s2.filter(&s_neq)?)?; let mut idx = 0; @@ -25,14 +25,17 @@ pub fn rle(s: &Series) -> PolarsResult { } } - let outvals = vec![Series::from_vec("len", lengths), vals.to_owned()]; - Ok(StructChunked::from_series(s.name(), &outvals)?.into_series()) + let outvals = vec![ + Series::from_vec(PlSmallStr::from_static("len"), lengths), + vals.to_owned(), + ]; + Ok(StructChunked::from_series(s.name().clone(), &outvals)?.into_series()) } /// Similar to `rle`, but maps values to run IDs. pub fn rle_id(s: &Series) -> PolarsResult { if s.len() == 0 { - return Ok(Series::new_empty(s.name(), &IDX_DTYPE)); + return Ok(Series::new_empty(s.name().clone(), &IDX_DTYPE)); } let (s1, s2) = (s.slice(0, s.len() - 1), s.slice(1, s.len())); let s_neq = s1.not_equal_missing(&s2)?; @@ -47,7 +50,7 @@ pub fn rle_id(s: &Series) -> PolarsResult { out.push(last); } } - Ok(IdxCa::from_vec(s.name(), out) + Ok(IdxCa::from_vec(s.name().clone(), out) .with_sorted_flag(IsSorted::Ascending) .into_series()) } diff --git a/crates/polars-ops/src/series/ops/round.rs b/crates/polars-ops/src/series/ops/round.rs index 2ee6c284d1b2..7ed6b2e40eed 100644 --- a/crates/polars-ops/src/series/ops/round.rs +++ b/crates/polars-ops/src/series/ops/round.rs @@ -101,7 +101,7 @@ mod test { #[test] fn test_round_series() { - let series = Series::new("a", &[1.003, 2.23222, 3.4352]); + let series = Series::new("a".into(), &[1.003, 2.23222, 3.4352]); let out = series.round(2).unwrap(); let ca = out.f64().unwrap(); assert_eq!(ca.get(0), Some(1.0)); diff --git a/crates/polars-ops/src/series/ops/search_sorted.rs b/crates/polars-ops/src/series/ops/search_sorted.rs index 78d5ba7eb134..11e97ef489e8 100644 --- a/crates/polars-ops/src/series/ops/search_sorted.rs +++ b/crates/polars-ops/src/series/ops/search_sorted.rs @@ -19,7 +19,29 @@ pub fn search_sorted( let search_values = search_values.str()?; let search_values = search_values.as_binary(); let idx = binary_search_ca(&ca, search_values.iter(), side, descending); - Ok(IdxCa::new_vec(s.name(), idx)) + Ok(IdxCa::new_vec(s.name().clone(), idx)) + }, + DataType::Boolean => { + let ca = s.bool().unwrap(); + let search_values = search_values.bool()?; + + let mut none_pos = None; + let mut false_pos = None; + let mut true_pos = None; + let idxs = search_values + .iter() + .map(|v| { + let cache = match v { + None => &mut none_pos, + Some(false) => &mut false_pos, + Some(true) => &mut true_pos, + }; + *cache.get_or_insert_with(|| 
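rle_id above gives each element the number of value changes seen so far, so equal neighbours share a run id and the ids are ascending. A small standalone sketch over i64 values:

// Run id = cumulative count of positions where a value differs from its predecessor.
fn rle_id(values: &[i64]) -> Vec<u32> {
    let mut out = Vec::with_capacity(values.len());
    let mut id = 0u32;
    for (i, v) in values.iter().enumerate() {
        if i > 0 && *v != values[i - 1] {
            id += 1;
        }
        out.push(id);
    }
    out
}

fn main() {
    assert_eq!(rle_id(&[1, 1, 2, 2, 2, 3, 1]), vec![0, 0, 1, 1, 1, 2, 3]);
}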
{ + binary_search_ca(ca, [v].into_iter(), side, descending)[0] + }) + }) + .collect(); + Ok(IdxCa::new_vec(s.name().clone(), idxs)) }, DataType::Binary => { let ca = s.binary().unwrap(); @@ -36,7 +58,7 @@ pub fn search_sorted( _ => unreachable!(), }; - Ok(IdxCa::new_vec(s.name(), idx)) + Ok(IdxCa::new_vec(s.name().clone(), idx)) }, dt if dt.is_numeric() => { let search_values = search_values.to_physical_repr(); @@ -46,7 +68,7 @@ pub fn search_sorted( let search_values: &ChunkedArray<$T> = search_values.as_ref().as_ref().as_ref(); binary_search_ca(ca, search_values.iter(), side, descending) }); - Ok(IdxCa::new_vec(s.name(), idx)) + Ok(IdxCa::new_vec(s.name().clone(), idx)) }, _ => polars_bail!(opq = search_sorted, original_dtype), } diff --git a/crates/polars-ops/src/series/ops/to_dummies.rs b/crates/polars-ops/src/series/ops/to_dummies.rs index f2d8c4f3b70a..3cd9d426ac1d 100644 --- a/crates/polars-ops/src/series/ops/to_dummies.rs +++ b/crates/polars-ops/src/series/ops/to_dummies.rs @@ -1,3 +1,5 @@ +use polars_utils::format_pl_smallstr; + use super::*; #[cfg(feature = "dtype-u8")] @@ -28,18 +30,16 @@ impl ToDummies for Series { // strings are formatted with extra \" \" in polars, so we // extract the string let name = if let Some(s) = av.get_str() { - format!("{col_name}{sep}{s}") + format_pl_smallstr!("{col_name}{sep}{s}") } else { // other types don't have this formatting issue - format!("{col_name}{sep}{av}") + format_pl_smallstr!("{col_name}{sep}{av}") }; let ca = match group { - GroupsIndicator::Idx((_, group)) => { - dummies_helper_idx(group, self.len(), &name) - }, + GroupsIndicator::Idx((_, group)) => dummies_helper_idx(group, self.len(), name), GroupsIndicator::Slice([offset, len]) => { - dummies_helper_slice(offset, len, self.len(), &name) + dummies_helper_slice(offset, len, self.len(), name) }, }; ca.into_series() @@ -50,7 +50,7 @@ impl ToDummies for Series { } } -fn dummies_helper_idx(groups: &[IdxSize], len: usize, name: &str) -> DummyCa { +fn dummies_helper_idx(groups: &[IdxSize], len: usize, name: PlSmallStr) -> DummyCa { let mut av = vec![0 as DummyType; len]; for &idx in groups { @@ -65,7 +65,7 @@ fn dummies_helper_slice( group_offset: IdxSize, group_len: IdxSize, len: usize, - name: &str, + name: PlSmallStr, ) -> DummyCa { let mut av = vec![0 as DummyType; len]; diff --git a/crates/polars-ops/src/series/ops/unique.rs b/crates/polars-ops/src/series/ops/unique.rs index 3a2d9b5652fe..e48509b1ce73 100644 --- a/crates/polars-ops/src/series/ops/unique.rs +++ b/crates/polars-ops/src/series/ops/unique.rs @@ -41,9 +41,9 @@ pub fn unique_counts(s: &Series) -> PolarsResult { }, DataType::Null => { let ca = if s.is_empty() { - IdxCa::new(s.name(), [] as [IdxSize; 0]) + IdxCa::new(s.name().clone(), [] as [IdxSize; 0]) } else { - IdxCa::new(s.name(), [s.len() as IdxSize]) + IdxCa::new(s.name().clone(), [s.len() as IdxSize]) }; Ok(ca.into_series()) }, diff --git a/crates/polars-ops/src/series/ops/various.rs b/crates/polars-ops/src/series/ops/various.rs index 1cb2f0d708a8..9ad21ab617d3 100644 --- a/crates/polars-ops/src/series/ops/various.rs +++ b/crates/polars-ops/src/series/ops/various.rs @@ -16,18 +16,19 @@ pub trait SeriesMethods: SeriesSealed { &self, sort: bool, parallel: bool, - name: String, + name: PlSmallStr, normalize: bool, ) -> PolarsResult { let s = self.as_series(); polars_ensure!( - s.name() != name, - Duplicate: "using `value_counts` on a column/series named '{}' would lead to duplicate column names; change `name` to fix", name, + s.name() != &name, + Duplicate: "using 
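The boolean branch added to search_sorted above exploits that there are only three possible needles (None, false, true): each insertion point is computed once and reused for every later occurrence. A simplified sketch for non-null needles, assuming an ascending column and left-most insertion via partition_point:

fn search_sorted_bool(haystack: &[bool], needles: &[bool]) -> Vec<usize> {
    let mut false_pos: Option<usize> = None;
    let mut true_pos: Option<usize> = None;
    needles
        .iter()
        .map(|&v| {
            // Cache per distinct needle value, as in the hunk above.
            let cache = if v { &mut true_pos } else { &mut false_pos };
            *cache.get_or_insert_with(|| haystack.partition_point(|&x| x < v))
        })
        .collect()
}

fn main() {
    let hay = [false, false, true, true, true];
    assert_eq!(search_sorted_bool(&hay, &[true, false, true]), vec![2, 0, 2]);
}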
`value_counts` on a column/series named '{}' would lead to duplicate \ + column names; change `name` to fix", name, ); // we need to sort here as well in case of `maintain_order` because duplicates behavior is undefined let groups = s.group_tuples(parallel, sort)?; let values = unsafe { s.agg_first(&groups) }; - let counts = groups.group_count().with_name(name.as_str()); + let counts = groups.group_count().with_name(name.clone()); let counts = if normalize { let len = s.len() as f64; @@ -53,7 +54,7 @@ pub trait SeriesMethods: SeriesSealed { } #[cfg(feature = "hash")] - fn hash(&self, build_hasher: ahash::RandomState) -> UInt64Chunked { + fn hash(&self, build_hasher: PlRandomState) -> UInt64Chunked { let s = self.as_series().to_physical_repr(); match s.dtype() { DataType::List(_) => { @@ -63,7 +64,7 @@ pub trait SeriesMethods: SeriesSealed { _ => { let mut h = vec![]; s.0.vec_hash(build_hasher, &mut h).unwrap(); - UInt64Chunked::from_vec(s.name(), h) + UInt64Chunked::from_vec(s.name().clone(), h) }, } } @@ -93,7 +94,7 @@ pub trait SeriesMethods: SeriesSealed { #[cfg(feature = "dtype-struct")] if matches!(s.dtype(), DataType::Struct(_)) { let encoded = _get_rows_encoded_ca( - "", + PlSmallStr::EMPTY, &[s.clone()], &[options.descending], &[options.nulls_last], diff --git a/crates/polars-parquet/Cargo.toml b/crates/polars-parquet/Cargo.toml index 5c62479ccaa3..26a57b22e713 100644 --- a/crates/polars-parquet/Cargo.toml +++ b/crates/polars-parquet/Cargo.toml @@ -20,6 +20,7 @@ bytemuck = { workspace = true } ethnum = { workspace = true } fallible-streaming-iterator = { workspace = true, optional = true } futures = { workspace = true, optional = true } +hashbrown = { workspace = true } num-traits = { workspace = true } polars-compute = { workspace = true } polars-error = { workspace = true } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binary/basic.rs b/crates/polars-parquet/src/arrow/read/deserialize/binary/basic.rs deleted file mode 100644 index 4c17b7bd2982..000000000000 --- a/crates/polars-parquet/src/arrow/read/deserialize/binary/basic.rs +++ /dev/null @@ -1,326 +0,0 @@ -use std::default::Default; -use std::sync::atomic::{AtomicBool, Ordering}; - -use arrow::array::specification::try_check_utf8; -use arrow::array::{Array, BinaryArray, DictionaryArray, DictionaryKey, PrimitiveArray, Utf8Array}; -use arrow::bitmap::MutableBitmap; -use arrow::datatypes::{ArrowDataType, PhysicalType}; -use arrow::offset::Offset; - -use super::super::utils; -use super::super::utils::extend_from_decoder; -use super::decoders::*; -use super::utils::*; -use crate::parquet::encoding::hybrid_rle::gatherer::HybridRleGatherer; -use crate::parquet::encoding::hybrid_rle::HybridRleDecoder; -use crate::parquet::error::{ParquetError, ParquetResult}; -use crate::parquet::page::{DataPage, DictPage}; -use crate::read::deserialize::utils::{Decoder, GatheredHybridRle, StateTranslation}; -use crate::read::PrimitiveLogicalType; - -impl utils::ExactSize for (Binary, MutableBitmap) { - fn len(&self) -> usize { - self.0.len() - } -} - -impl<'a, O: Offset> StateTranslation<'a, BinaryDecoder> for BinaryStateTranslation<'a> { - type PlainDecoder = BinaryIter<'a>; - - fn new( - decoder: &BinaryDecoder, - page: &'a DataPage, - dict: Option<&'a as utils::Decoder>::Dict>, - page_validity: Option<&utils::PageValidity<'a>>, - ) -> ParquetResult { - let is_string = matches!( - page.descriptor.primitive_type.logical_type, - Some(PrimitiveLogicalType::String) - ); - decoder.check_utf8.store(is_string, Ordering::Relaxed); - 
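value_counts above groups equal values, counts the group sizes, and divides the counts by the series length when normalize is set. A rough sketch with a HashMap standing in for the grouped DataFrame:

use std::collections::HashMap;

fn value_counts(values: &[i64], normalize: bool) -> HashMap<i64, f64> {
    let mut counts: HashMap<i64, f64> = HashMap::new();
    for &v in values {
        *counts.entry(v).or_insert(0.0) += 1.0;
    }
    if normalize {
        let len = values.len() as f64;
        for c in counts.values_mut() {
            *c /= len;
        }
    }
    counts
}

fn main() {
    let c = value_counts(&[1, 1, 2], true);
    assert_eq!(c[&1], 2.0 / 3.0);
}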
BinaryStateTranslation::new(page, dict, page_validity, is_string) - } - - fn len_when_not_nullable(&self) -> usize { - BinaryStateTranslation::len_when_not_nullable(self) - } - - fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { - BinaryStateTranslation::skip_in_place(self, n) - } - - fn extend_from_state( - &mut self, - decoder: &mut BinaryDecoder, - decoded: &mut as utils::Decoder>::DecodedState, - page_validity: &mut Option>, - additional: usize, - ) -> ParquetResult<()> { - let len_before = decoded.0.offsets.len(); - - use BinaryStateTranslation as T; - match self { - T::Plain(page_values) => decoder.decode_plain_encoded( - decoded, - page_values, - page_validity.as_mut(), - additional, - )?, - T::Dictionary(page) => decoder.decode_dictionary_encoded( - decoded, - &mut page.values, - page_validity.as_mut(), - page.dict, - additional, - )?, - T::Delta(page) => { - let (values, validity) = decoded; - - match page_validity { - None => values - .extend_lengths(page.lengths.by_ref().take(additional), &mut page.values), - Some(page_validity) => { - let Binary { - offsets, - values: values_, - } = values; - - let last_offset = *offsets.last(); - extend_from_decoder( - validity, - page_validity, - Some(additional), - offsets, - page.lengths.by_ref(), - )?; - - let length = *offsets.last() - last_offset; - - let (consumed, remaining) = page.values.split_at(length.to_usize()); - page.values = remaining; - values_.extend_from_slice(consumed); - }, - } - }, - T::DeltaBytes(page_values) => { - let (values, validity) = decoded; - - match page_validity { - None => { - for x in page_values.take(additional) { - values.push(x) - } - }, - Some(page_validity) => { - extend_from_decoder( - validity, - page_validity, - Some(additional), - values, - page_values, - )?; - }, - } - }, - } - - // @TODO: Double checking - if decoder.check_utf8.load(Ordering::Relaxed) { - // @TODO: This can report a better error. - let offsets = &decoded.0.offsets.as_slice()[len_before..]; - try_check_utf8(offsets, &decoded.0.values) - .map_err(|_| ParquetError::oos("invalid utf-8"))?; - } - - Ok(()) - } -} - -#[derive(Debug, Default)] -pub(crate) struct BinaryDecoder { - phantom_o: std::marker::PhantomData, - check_utf8: AtomicBool, -} - -impl utils::ExactSize for BinaryDict { - fn len(&self) -> usize { - BinaryDict::len(self) - } -} - -impl utils::Decoder for BinaryDecoder { - type Translation<'a> = BinaryStateTranslation<'a>; - type Dict = BinaryDict; - type DecodedState = (Binary, MutableBitmap); - type Output = Box; - - fn with_capacity(&self, capacity: usize) -> Self::DecodedState { - ( - Binary::::with_capacity(capacity), - MutableBitmap::with_capacity(capacity), - ) - } - - fn deserialize_dict(&self, page: DictPage) -> Self::Dict { - deserialize_plain(&page.buffer, page.num_values) - } - - fn decode_plain_encoded<'a>( - &mut self, - (values, validity): &mut Self::DecodedState, - page_values: &mut as StateTranslation<'a, Self>>::PlainDecoder, - page_validity: Option<&mut utils::PageValidity<'a>>, - limit: usize, - ) -> ParquetResult<()> { - let len_before = values.offsets.len(); - - match page_validity { - None => { - for x in page_values.by_ref().take(limit) { - values.push(x); - } - }, - Some(page_validity) => { - extend_from_decoder(validity, page_validity, Some(limit), values, page_values)? - }, - } - - if self.check_utf8.load(Ordering::Relaxed) { - // @TODO: This can report a better error. 
- let offsets = &values.offsets.as_slice()[len_before..]; - try_check_utf8(offsets, &values.values).map_err(|_| ParquetError::oos("invalid utf-8")) - } else { - Ok(()) - } - } - - fn decode_dictionary_encoded<'a>( - &mut self, - (values, validity): &mut Self::DecodedState, - page_values: &mut HybridRleDecoder<'a>, - page_validity: Option<&mut utils::PageValidity<'a>>, - dict: &Self::Dict, - limit: usize, - ) -> ParquetResult<()> { - struct BinaryGatherer<'a, O> { - dict: &'a BinaryDict, - _pd: std::marker::PhantomData, - } - - impl<'a, O: Offset> HybridRleGatherer<&'a [u8]> for BinaryGatherer<'a, O> { - type Target = Binary; - - fn target_reserve(&self, target: &mut Self::Target, n: usize) { - // @NOTE: This is an estimation for the reservation. It will probably not be - // accurate, but then it is a lot better than not allocating. - target.offsets.reserve(n); - target.values.reserve(n); - } - - fn target_num_elements(&self, target: &Self::Target) -> usize { - target.offsets.len_proxy() - } - - fn hybridrle_to_target(&self, value: u32) -> ParquetResult<&'a [u8]> { - let value = value as usize; - - if value >= self.dict.len() { - return Err(ParquetError::oos("Binary dictionary index out-of-range")); - } - - Ok(self.dict.value(value)) - } - - fn gather_one(&self, target: &mut Self::Target, value: &'a [u8]) -> ParquetResult<()> { - target.push(value); - Ok(()) - } - - fn gather_repeated( - &self, - target: &mut Self::Target, - value: &'a [u8], - n: usize, - ) -> ParquetResult<()> { - for _ in 0..n { - target.push(value); - } - Ok(()) - } - } - - let gatherer = BinaryGatherer { - dict, - _pd: std::marker::PhantomData, - }; - - match page_validity { - None => { - page_values.gather_n_into(values, limit, &gatherer)?; - }, - Some(page_validity) => { - let collector = GatheredHybridRle::new(page_values, &gatherer, &[]); - - extend_from_decoder(validity, page_validity, Some(limit), values, collector)?; - }, - } - - Ok(()) - } - - fn finalize( - &self, - data_type: ArrowDataType, - _dict: Option, - (values, validity): Self::DecodedState, - ) -> ParquetResult> { - super::finalize(data_type, values, validity) - } -} - -impl utils::DictDecodable for BinaryDecoder { - fn finalize_dict_array( - &self, - data_type: ArrowDataType, - dict: Self::Dict, - keys: PrimitiveArray, - ) -> ParquetResult> { - let value_data_type = match data_type.clone() { - ArrowDataType::Dictionary(_, values, _) => *values, - v => v, - }; - - let (_, dict_offsets, dict_values, _) = dict.into_inner(); - let dict = match value_data_type.to_physical_type() { - PhysicalType::Utf8 | PhysicalType::LargeUtf8 => { - Utf8Array::new(value_data_type, dict_offsets, dict_values, None).boxed() - }, - PhysicalType::Binary | PhysicalType::LargeBinary => { - BinaryArray::new(value_data_type, dict_offsets, dict_values, None).boxed() - }, - _ => unreachable!(), - }; - - // @TODO: Is this datatype correct? 
- Ok(DictionaryArray::try_new(data_type, keys, dict).unwrap()) - } -} - -impl utils::NestedDecoder for BinaryDecoder { - fn validity_extend( - _: &mut utils::State<'_, Self>, - (_, validity): &mut Self::DecodedState, - value: bool, - n: usize, - ) { - validity.extend_constant(n, value); - } - - fn values_extend_nulls( - _: &mut utils::State<'_, Self>, - (values, _): &mut Self::DecodedState, - n: usize, - ) { - values.extend_constant(n); - } -} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binary/decoders.rs b/crates/polars-parquet/src/arrow/read/deserialize/binary/decoders.rs deleted file mode 100644 index 27ee2a9a251c..000000000000 --- a/crates/polars-parquet/src/arrow/read/deserialize/binary/decoders.rs +++ /dev/null @@ -1,213 +0,0 @@ -use arrow::array::specification::try_check_utf8; -use arrow::array::{BinaryArray, MutableBinaryValuesArray}; -use polars_error::PolarsResult; - -use super::super::utils; -use super::utils::*; -use crate::parquet::encoding::{delta_bitpacked, delta_length_byte_array, hybrid_rle, Encoding}; -use crate::parquet::error::ParquetResult; -use crate::parquet::page::{split_buffer, DataPage}; -use crate::read::deserialize::utils::PageValidity; - -pub(crate) type BinaryDict = BinaryArray; - -#[derive(Debug)] -pub(crate) struct Delta<'a> { - pub lengths: std::vec::IntoIter, - pub values: &'a [u8], -} - -impl<'a> Delta<'a> { - pub fn try_new(page: &'a DataPage) -> PolarsResult { - let values = split_buffer(page)?.values; - - let mut lengths_iter = delta_length_byte_array::Decoder::try_new(values)?; - - #[allow(clippy::needless_collect)] // we need to consume it to get the values - let lengths = lengths_iter - .by_ref() - .map(|x| x.map(|x| x as usize)) - .collect::>>()?; - - let values = lengths_iter.into_values(); - Ok(Self { - lengths: lengths.into_iter(), - values, - }) - } - - pub fn len(&self) -> usize { - self.lengths.size_hint().0 - } -} - -impl<'a> Iterator for Delta<'a> { - type Item = &'a [u8]; - - #[inline] - fn next(&mut self) -> Option { - let length = self.lengths.next()?; - let (item, remaining) = self.values.split_at(length); - self.values = remaining; - Some(item) - } - - fn size_hint(&self) -> (usize, Option) { - self.lengths.size_hint() - } -} - -#[derive(Debug)] -pub(crate) struct DeltaBytes<'a> { - prefix: std::vec::IntoIter, - suffix: std::vec::IntoIter, - data: &'a [u8], - data_offset: usize, - last_value: Vec, -} - -impl<'a> DeltaBytes<'a> { - pub fn try_new(page: &'a DataPage) -> PolarsResult { - let values = split_buffer(page)?.values; - let mut decoder = delta_bitpacked::Decoder::try_new(values)?; - let prefix = (&mut decoder) - .take(page.num_values()) - .map(|r| r.map(|v| v as i32).unwrap()) - .collect::>(); - - let mut data_offset = decoder.consumed_bytes(); - let mut decoder = delta_bitpacked::Decoder::try_new(&values[decoder.consumed_bytes()..])?; - let suffix = (&mut decoder) - .map(|r| r.map(|v| v as i32).unwrap()) - .collect::>(); - data_offset += decoder.consumed_bytes(); - - Ok(Self { - prefix: prefix.into_iter(), - suffix: suffix.into_iter(), - data: values, - data_offset, - last_value: vec![], - }) - } -} - -impl<'a> Iterator for DeltaBytes<'a> { - type Item = &'a [u8]; - - #[inline] - fn next(&mut self) -> Option { - let prefix_len = self.prefix.next()? as usize; - let suffix_len = self.suffix.next()? 
as usize; - - self.last_value.truncate(prefix_len); - self.last_value - .extend_from_slice(&self.data[self.data_offset..self.data_offset + suffix_len]); - self.data_offset += suffix_len; - - // SAFETY: the consumer will only keep one value around per iteration. - // We need a different API for this to work with safe code. - let extend_lifetime = - unsafe { std::mem::transmute::<&[u8], &'a [u8]>(self.last_value.as_slice()) }; - Some(extend_lifetime) - } - - fn size_hint(&self) -> (usize, Option) { - self.prefix.size_hint() - } -} - -#[derive(Debug)] -pub(crate) struct ValuesDictionary<'a> { - pub values: hybrid_rle::HybridRleDecoder<'a>, - pub dict: &'a BinaryDict, -} - -impl<'a> ValuesDictionary<'a> { - pub fn try_new(page: &'a DataPage, dict: &'a BinaryDict) -> PolarsResult { - let values = utils::dict_indices_decoder(page)?; - - Ok(Self { dict, values }) - } - - #[inline] - pub fn len(&self) -> usize { - self.values.len() - } -} - -#[derive(Debug)] -pub(crate) enum BinaryStateTranslation<'a> { - Plain(BinaryIter<'a>), - Dictionary(ValuesDictionary<'a>), - Delta(Delta<'a>), - DeltaBytes(DeltaBytes<'a>), -} - -impl<'a> BinaryStateTranslation<'a> { - pub(crate) fn new( - page: &'a DataPage, - dict: Option<&'a BinaryDict>, - _page_validity: Option<&PageValidity<'a>>, - is_string: bool, - ) -> ParquetResult { - match (page.encoding(), dict) { - (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict)) => { - if is_string { - try_check_utf8(dict.offsets(), dict.values())?; - } - Ok(BinaryStateTranslation::Dictionary( - ValuesDictionary::try_new(page, dict)?, - )) - }, - (Encoding::Plain, _) => { - let values = split_buffer(page)?.values; - let values = BinaryIter::new(values, page.num_values()); - - Ok(BinaryStateTranslation::Plain(values)) - }, - (Encoding::DeltaLengthByteArray, _) => { - Ok(BinaryStateTranslation::Delta(Delta::try_new(page)?)) - }, - (Encoding::DeltaByteArray, _) => Ok(BinaryStateTranslation::DeltaBytes( - DeltaBytes::try_new(page)?, - )), - _ => Err(utils::not_implemented(page)), - } - } - pub(crate) fn len_when_not_nullable(&self) -> usize { - match self { - Self::Plain(v) => v.len_when_not_nullable(), - Self::Dictionary(v) => v.len(), - Self::Delta(v) => v.len(), - Self::DeltaBytes(v) => v.size_hint().0, - } - } - - pub(crate) fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { - if n == 0 { - return Ok(()); - } - - match self { - Self::Plain(t) => _ = t.by_ref().nth(n - 1), - Self::Dictionary(t) => t.values.skip_in_place(n)?, - Self::Delta(t) => _ = t.by_ref().nth(n - 1), - Self::DeltaBytes(t) => _ = t.by_ref().nth(n - 1), - } - - Ok(()) - } -} - -pub(crate) fn deserialize_plain(values: &[u8], num_values: usize) -> BinaryDict { - // Each value is prepended by the length which is 4 bytes. 
- let num_bytes = values.len() - 4 * num_values; - - let mut dict_values = MutableBinaryValuesArray::::with_capacities(num_values, num_bytes); - for v in BinaryIter::new(values, num_values) { - dict_values.push(v) - } - - dict_values.into() -} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binary/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/binary/mod.rs deleted file mode 100644 index ca36cc0107d3..000000000000 --- a/crates/polars-parquet/src/arrow/read/deserialize/binary/mod.rs +++ /dev/null @@ -1,45 +0,0 @@ -mod basic; -pub(super) mod decoders; -pub(super) mod utils; - -use arrow::array::{Array, BinaryArray, Utf8Array}; -use arrow::bitmap::MutableBitmap; -use arrow::datatypes::{ArrowDataType, PhysicalType}; -use arrow::types::Offset; -pub(crate) use basic::BinaryDecoder; - -use self::utils::Binary; -use super::utils::freeze_validity; -use super::ParquetResult; - -fn finalize( - data_type: ArrowDataType, - mut values: Binary, - validity: MutableBitmap, -) -> ParquetResult> { - values.offsets.shrink_to_fit(); - values.values.shrink_to_fit(); - let validity = freeze_validity(validity); - - match data_type.to_physical_type() { - PhysicalType::Binary | PhysicalType::LargeBinary => unsafe { - Ok(BinaryArray::::new_unchecked( - data_type, - values.offsets.into(), - values.values.into(), - validity, - ) - .boxed()) - }, - PhysicalType::Utf8 | PhysicalType::LargeUtf8 => unsafe { - Ok(Utf8Array::::new_unchecked( - data_type, - values.offsets.into(), - values.values.into(), - validity, - ) - .boxed()) - }, - _ => unreachable!(), - } -} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binary/utils.rs b/crates/polars-parquet/src/arrow/read/deserialize/binary/utils.rs deleted file mode 100644 index 57e263d93587..000000000000 --- a/crates/polars-parquet/src/arrow/read/deserialize/binary/utils.rs +++ /dev/null @@ -1,142 +0,0 @@ -use arrow::offset::{Offset, Offsets}; -use arrow::pushable::Pushable; - -/// [`Pushable`] for variable length binary data. 
-#[derive(Debug, Default)] -pub struct Binary { - pub offsets: Offsets, - pub values: Vec, -} - -impl Binary { - #[inline] - pub fn with_capacity(capacity: usize) -> Self { - Self { - offsets: Offsets::with_capacity(capacity), - values: Vec::with_capacity(capacity.min(100) * 24), - } - } - - #[inline] - pub fn push(&mut self, v: &[u8]) { - if self.offsets.len_proxy() == 100 && self.offsets.capacity() > 100 { - let bytes_per_row = self.values.len() / 100 + 1; - let bytes_estimate = bytes_per_row * self.offsets.capacity(); - if bytes_estimate > self.values.capacity() { - self.values.reserve(bytes_estimate - self.values.capacity()); - } - } - - self.values.extend(v); - self.offsets.try_push(v.len()).unwrap() - } - - #[inline] - pub fn extend_constant(&mut self, additional: usize) { - self.offsets.extend_constant(additional); - } - - #[inline] - pub fn len(&self) -> usize { - self.offsets.len_proxy() - } - - #[inline] - pub fn extend_lengths>(&mut self, lengths: I, values: &mut &[u8]) { - let current_offset = *self.offsets.last(); - self.offsets.try_extend_from_lengths(lengths).unwrap(); - let new_offset = *self.offsets.last(); - let length = new_offset.to_usize() - current_offset.to_usize(); - let (consumed, remaining) = values.split_at(length); - *values = remaining; - self.values.extend_from_slice(consumed); - } -} - -impl<'a, O: Offset> Pushable<&'a [u8]> for Binary { - type Freeze = (); - #[inline] - fn reserve(&mut self, additional: usize) { - let avg_len = self.values.len() / std::cmp::max(self.offsets.last().to_usize(), 1); - self.values.reserve(additional * avg_len); - self.offsets.reserve(additional); - } - #[inline] - fn len(&self) -> usize { - self.len() - } - - #[inline] - fn push_null(&mut self) { - self.push(&[]) - } - - #[inline] - fn push(&mut self, value: &[u8]) { - self.push(value) - } - - #[inline] - fn extend_constant(&mut self, additional: usize, value: &[u8]) { - assert_eq!(value.len(), 0); - self.extend_constant(additional) - } - - #[inline] - fn extend_null_constant(&mut self, additional: usize) { - self.extend_constant(additional) - } - fn freeze(self) -> Self::Freeze { - unimplemented!() - } -} - -#[derive(Debug)] -pub struct BinaryIter<'a> { - values: &'a [u8], - - /// A maximum number of items that this [`BinaryIter`] may produce. - /// - /// This equal the length of the iterator i.f.f. the data encoded by the [`BinaryIter`] is not - /// nullable. - max_num_values: usize, -} - -impl<'a> BinaryIter<'a> { - pub fn new(values: &'a [u8], max_num_values: usize) -> Self { - Self { - values, - max_num_values, - } - } - - /// Return the length of the iterator when the data is not nullable. 
- pub fn len_when_not_nullable(&self) -> usize { - self.max_num_values - } -} - -impl<'a> Iterator for BinaryIter<'a> { - type Item = &'a [u8]; - - #[inline] - fn next(&mut self) -> Option { - if self.max_num_values == 0 { - assert!(self.values.is_empty()); - return None; - } - - let (length, remaining) = self.values.split_at(4); - let length: [u8; 4] = unsafe { length.try_into().unwrap_unchecked() }; - let length = u32::from_le_bytes(length) as usize; - let (result, remaining) = remaining.split_at(length); - self.max_num_values -= 1; - self.values = remaining; - Some(result) - } - - #[inline] - fn size_hint(&self) -> (usize, Option) { - (0, Some(self.max_num_values)) - } -} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binview.rs b/crates/polars-parquet/src/arrow/read/deserialize/binview.rs index c9d2f6486017..6777f7e639c9 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/binview.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/binview.rs @@ -1,3 +1,4 @@ +use std::mem::MaybeUninit; use std::sync::atomic::{AtomicBool, Ordering}; use arrow::array::{ @@ -5,52 +6,93 @@ use arrow::array::{ Utf8ViewArray, View, }; use arrow::bitmap::MutableBitmap; +use arrow::buffer::Buffer; use arrow::datatypes::{ArrowDataType, PhysicalType}; -use super::binary::decoders::*; -use super::utils::freeze_validity; -use crate::parquet::encoding::hybrid_rle::{self, DictionaryTranslator}; +use super::utils::{dict_indices_decoder, freeze_validity, BatchableCollector}; +use crate::parquet::encoding::delta_bitpacked::{lin_natural_sum, DeltaGatherer}; +use crate::parquet::encoding::hybrid_rle::gatherer::HybridRleGatherer; +use crate::parquet::encoding::{delta_byte_array, delta_length_byte_array, hybrid_rle, Encoding}; use crate::parquet::error::{ParquetError, ParquetResult}; -use crate::parquet::page::{DataPage, DictPage}; -use crate::read::deserialize::binary::utils::BinaryIter; -use crate::read::deserialize::utils::{ - self, binary_views_dict, extend_from_decoder, Decoder, PageValidity, StateTranslation, - TranslatedHybridRle, -}; +use crate::parquet::page::{split_buffer, DataPage, DictPage}; +use crate::read::deserialize::utils::{self, extend_from_decoder, Decoder, PageValidity}; use crate::read::PrimitiveLogicalType; type DecodedStateTuple = (MutableBinaryViewArray<[u8]>, MutableBitmap); -impl<'a> StateTranslation<'a, BinViewDecoder> for BinaryStateTranslation<'a> { +impl<'a> utils::StateTranslation<'a, BinViewDecoder> for StateTranslation<'a> { type PlainDecoder = BinaryIter<'a>; fn new( decoder: &BinViewDecoder, page: &'a DataPage, dict: Option<&'a ::Dict>, - page_validity: Option<&PageValidity<'a>>, + _page_validity: Option<&PageValidity<'a>>, ) -> ParquetResult { let is_string = matches!( page.descriptor.primitive_type.logical_type, Some(PrimitiveLogicalType::String) ); decoder.check_utf8.store(is_string, Ordering::Relaxed); - Self::new(page, dict, page_validity, is_string) + match (page.encoding(), dict) { + (Encoding::Plain, _) => { + let values = split_buffer(page)?.values; + let values = BinaryIter::new(values, page.num_values()); + + Ok(Self::Plain(values)) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(_)) => { + let values = dict_indices_decoder(page)?; + Ok(Self::Dictionary(values)) + }, + (Encoding::DeltaLengthByteArray, _) => { + let values = split_buffer(page)?.values; + Ok(Self::DeltaLengthByteArray( + delta_length_byte_array::Decoder::try_new(values)?, + Vec::new(), + )) + }, + (Encoding::DeltaByteArray, _) => { + let values = 
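The removed BinaryIter walked Parquet's PLAIN byte-array layout, where every value is prefixed by its length as four little-endian bytes. A standalone sketch of that framing, without the max_num_values bookkeeping of the original:

// Decode a length-prefixed buffer into the individual byte slices.
fn read_plain(mut buf: &[u8]) -> Vec<&[u8]> {
    let mut out = Vec::new();
    while !buf.is_empty() {
        let (len, rest) = buf.split_at(4);
        let len = u32::from_le_bytes(len.try_into().unwrap()) as usize;
        let (value, rest) = rest.split_at(len);
        out.push(value);
        buf = rest;
    }
    out
}

fn main() {
    // Two values: "ab" and "c".
    let buf = [2, 0, 0, 0, b'a', b'b', 1, 0, 0, 0, b'c'];
    assert_eq!(read_plain(&buf), vec![b"ab".as_slice(), b"c".as_slice()]);
}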
split_buffer(page)?.values; + Ok(Self::DeltaBytes(delta_byte_array::Decoder::try_new( + values, + )?)) + }, + _ => Err(utils::not_implemented(page)), + } } fn len_when_not_nullable(&self) -> usize { - Self::len_when_not_nullable(self) + match self { + Self::Plain(v) => v.len_when_not_nullable(), + Self::Dictionary(v) => v.len(), + Self::DeltaLengthByteArray(v, _) => v.len(), + Self::DeltaBytes(v) => v.len(), + } } fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { - Self::skip_in_place(self, n) + if n == 0 { + return Ok(()); + } + + match self { + Self::Plain(t) => _ = t.by_ref().nth(n - 1), + Self::Dictionary(t) => t.skip_in_place(n)?, + Self::DeltaLengthByteArray(t, _) => t.skip_in_place(n)?, + Self::DeltaBytes(t) => t.skip_in_place(n)?, + } + + Ok(()) } fn extend_from_state( &mut self, decoder: &mut BinViewDecoder, decoded: &mut ::DecodedState, + is_optional: bool, page_validity: &mut Option>, + dict: Option<&'a ::Dict>, additional: usize, ) -> ParquetResult<()> { let views_offset = decoded.0.views().len(); @@ -63,6 +105,7 @@ impl<'a> StateTranslation<'a, BinViewDecoder> for BinaryStateTranslation<'a> { decoder.decode_plain_encoded( decoded, page_values, + is_optional, page_validity.as_mut(), additional, )?; @@ -70,43 +113,62 @@ impl<'a> StateTranslation<'a, BinViewDecoder> for BinaryStateTranslation<'a> { // Already done in decode_plain_encoded validate_utf8 = false; }, - Self::Dictionary(page) => { + Self::Dictionary(ref mut page) => { + let dict = dict.unwrap(); + decoder.decode_dictionary_encoded( decoded, - &mut page.values, + page, + is_optional, page_validity.as_mut(), - page.dict, + dict, additional, )?; // Already done in decode_plain_encoded validate_utf8 = false; }, - Self::Delta(page_values) => { + Self::DeltaLengthByteArray(ref mut page_values, ref mut lengths) => { let (values, validity) = decoded; + + let mut collector = DeltaCollector { + gatherer: &mut StatGatherer::default(), + pushed_lengths: lengths, + decoder: page_values, + }; + match page_validity { None => { - for value in page_values.by_ref().take(additional) { - values.push_value_ignore_validity(value) + (&mut collector).push_n(values, additional)?; + + if is_optional { + validity.extend_constant(additional, true); } }, - Some(page_validity) => { - extend_from_decoder( - validity, - page_validity, - Some(additional), - values, - page_values, - )?; - }, + Some(page_validity) => extend_from_decoder( + validity, + page_validity, + Some(additional), + values, + &mut collector, + )?, } + + collector.flush(values); }, - Self::DeltaBytes(page_values) => { + Self::DeltaBytes(ref mut page_values) => { let (values, validity) = decoded; + + let mut collector = DeltaBytesCollector { + decoder: page_values, + }; + match page_validity { None => { - for x in page_values.take(additional) { - values.push_value_ignore_validity(x) + collector.push_n(values, additional)?; + + if is_optional { + validity.extend_constant(additional, true); } }, Some(page_validity) => extend_from_decoder( @@ -114,7 +176,7 @@ impl<'a> StateTranslation<'a, BinViewDecoder> for BinaryStateTranslation<'a> { page_validity, Some(additional), values, - page_values, + collector, )?, } }, @@ -134,7 +196,15 @@ impl<'a> StateTranslation<'a, BinViewDecoder> for BinaryStateTranslation<'a> { #[derive(Default)] pub(crate) struct BinViewDecoder { check_utf8: AtomicBool, - views_dict: Option>, +} + +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +pub(crate) enum StateTranslation<'a> { + Plain(BinaryIter<'a>), + 
Dictionary(hybrid_rle::HybridRleDecoder<'a>), + DeltaLengthByteArray(delta_length_byte_array::Decoder<'a>, Vec), + DeltaBytes(delta_byte_array::Decoder<'a>), } impl utils::ExactSize for DecodedStateTuple { @@ -143,9 +213,315 @@ impl utils::ExactSize for DecodedStateTuple { } } +impl utils::ExactSize for (Vec, Vec>) { + fn len(&self) -> usize { + self.0.len() + } +} + +pub(crate) struct DeltaCollector<'a, 'b> { + // We gatherer the decoded lengths into `pushed_lengths`. Then, we `flush` those to the + // `BinView` This allows us to group many memcopies into one and take better potential fast + // paths for inlineable views and such. + pub(crate) gatherer: &'b mut StatGatherer, + pub(crate) pushed_lengths: &'b mut Vec, + + pub(crate) decoder: &'b mut delta_length_byte_array::Decoder<'a>, +} + +pub(crate) struct DeltaBytesCollector<'a, 'b> { + pub(crate) decoder: &'b mut delta_byte_array::Decoder<'a>, +} + +/// A [`DeltaGatherer`] that gathers the minimum, maximum and summation of the values as `usize`s. +pub(crate) struct StatGatherer { + min: usize, + max: usize, + sum: usize, +} + +impl Default for StatGatherer { + fn default() -> Self { + Self { + min: usize::MAX, + max: usize::MIN, + sum: 0, + } + } +} + +impl DeltaGatherer for StatGatherer { + type Target = Vec; + + fn target_len(&self, target: &Self::Target) -> usize { + target.len() + } + + fn target_reserve(&self, target: &mut Self::Target, n: usize) { + target.reserve(n); + } + + fn gather_one(&mut self, target: &mut Self::Target, v: i64) -> ParquetResult<()> { + if v < 0 { + return Err(ParquetError::oos("DELTA_LENGTH_BYTE_ARRAY length < 0")); + } + + if v > i64::from(u32::MAX) { + return Err(ParquetError::not_supported( + "DELTA_LENGTH_BYTE_ARRAY length > u32::MAX", + )); + } + + let v = v as usize; + + self.min = self.min.min(v); + self.max = self.max.max(v); + self.sum += v; + + target.push(v as u32); + + Ok(()) + } + + fn gather_slice(&mut self, target: &mut Self::Target, slice: &[i64]) -> ParquetResult<()> { + let mut is_invalid = false; + let mut is_too_large = false; + + target.extend(slice.iter().map(|&v| { + is_invalid |= v < 0; + is_too_large |= v > i64::from(u32::MAX); + + let v = v as usize; + + self.min = self.min.min(v); + self.max = self.max.max(v); + self.sum += v; + + v as u32 + })); + + if is_invalid { + target.truncate(target.len() - slice.len()); + return Err(ParquetError::oos("DELTA_LENGTH_BYTE_ARRAY length < 0")); + } + + if is_too_large { + return Err(ParquetError::not_supported( + "DELTA_LENGTH_BYTE_ARRAY length > u32::MAX", + )); + } + + Ok(()) + } + + fn gather_constant( + &mut self, + target: &mut Self::Target, + v: i64, + delta: i64, + num_repeats: usize, + ) -> ParquetResult<()> { + if v < 0 || (delta < 0 && num_repeats > 0 && (num_repeats - 1) as i64 * delta + v < 0) { + return Err(ParquetError::oos("DELTA_LENGTH_BYTE_ARRAY length < 0")); + } + + if v > i64::from(u32::MAX) || v + ((num_repeats - 1) as i64) * delta > i64::from(u32::MAX) { + return Err(ParquetError::not_supported( + "DELTA_LENGTH_BYTE_ARRAY length > u32::MAX", + )); + } + + target.extend((0..num_repeats).map(|i| (v + (i as i64) * delta) as u32)); + + let vstart = v; + let vend = v + (num_repeats - 1) as i64 * delta; + + let (min, max) = if delta < 0 { + (vend, vstart) + } else { + (vstart, vend) + }; + + let sum = lin_natural_sum(v, delta, num_repeats) as usize; + + #[cfg(debug_assertions)] + { + assert_eq!( + (0..num_repeats) + .map(|i| (v + (i as i64) * delta) as usize) + .sum::(), + sum + ); + } + + self.min = self.min.min(min as 
usize); + self.max = self.max.max(max as usize); + self.sum += sum; + + Ok(()) + } +} + +impl<'a, 'b> BatchableCollector<(), MutableBinaryViewArray<[u8]>> for &mut DeltaCollector<'a, 'b> { + fn reserve(target: &mut MutableBinaryViewArray<[u8]>, n: usize) { + target.reserve(n); + } + + fn push_n( + &mut self, + _target: &mut MutableBinaryViewArray<[u8]>, + n: usize, + ) -> ParquetResult<()> { + self.decoder + .lengths + .gather_n_into(self.pushed_lengths, n, self.gatherer)?; + + Ok(()) + } + + fn push_n_nulls( + &mut self, + target: &mut MutableBinaryViewArray<[u8]>, + n: usize, + ) -> ParquetResult<()> { + self.flush(target); + target.extend_constant(n, >::None); + Ok(()) + } + + fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { + self.decoder.skip_in_place(n) + } +} + +impl<'a, 'b> DeltaCollector<'a, 'b> { + pub fn flush(&mut self, target: &mut MutableBinaryViewArray<[u8]>) { + if !self.pushed_lengths.is_empty() { + let start_bytes_len = target.total_bytes_len(); + let start_buffer_len = target.total_buffer_len(); + unsafe { + target.extend_from_lengths_with_stats( + &self.decoder.values[self.decoder.offset..], + self.pushed_lengths.iter().map(|&v| v as usize), + self.gatherer.min, + self.gatherer.max, + self.gatherer.sum, + ) + }; + debug_assert_eq!( + target.total_bytes_len() - start_bytes_len, + self.gatherer.sum, + ); + debug_assert_eq!( + target.total_buffer_len() - start_buffer_len, + self.pushed_lengths + .iter() + .map(|&v| v as usize) + .filter(|&v| v > View::MAX_INLINE_SIZE as usize) + .sum::(), + ); + + self.decoder.offset += self.gatherer.sum; + self.pushed_lengths.clear(); + *self.gatherer = StatGatherer::default(); + } + } +} + +impl<'a, 'b> BatchableCollector<(), MutableBinaryViewArray<[u8]>> for DeltaBytesCollector<'a, 'b> { + fn reserve(target: &mut MutableBinaryViewArray<[u8]>, n: usize) { + target.reserve(n); + } + + fn push_n(&mut self, target: &mut MutableBinaryViewArray<[u8]>, n: usize) -> ParquetResult<()> { + struct MaybeUninitCollector(usize); + + impl DeltaGatherer for MaybeUninitCollector { + type Target = [MaybeUninit; BATCH_SIZE]; + + fn target_len(&self, _target: &Self::Target) -> usize { + self.0 + } + + fn target_reserve(&self, _target: &mut Self::Target, _n: usize) {} + + fn gather_one(&mut self, target: &mut Self::Target, v: i64) -> ParquetResult<()> { + target[self.0] = MaybeUninit::new(v as usize); + self.0 += 1; + Ok(()) + } + } + + let decoder_len = self.decoder.len(); + let mut n = usize::min(n, decoder_len); + + if n == 0 { + return Ok(()); + } + + let mut buffer = Vec::new(); + target.reserve(n); + + const BATCH_SIZE: usize = 4096; + + let mut prefix_lengths = [const { MaybeUninit::::uninit() }; BATCH_SIZE]; + let mut suffix_lengths = [const { MaybeUninit::::uninit() }; BATCH_SIZE]; + + while n > 0 { + let num_elems = usize::min(n, BATCH_SIZE); + n -= num_elems; + + self.decoder.prefix_lengths.gather_n_into( + &mut prefix_lengths, + num_elems, + &mut MaybeUninitCollector(0), + )?; + self.decoder.suffix_lengths.gather_n_into( + &mut suffix_lengths, + num_elems, + &mut MaybeUninitCollector(0), + )?; + + for i in 0..num_elems { + let prefix_length = unsafe { prefix_lengths[i].assume_init() }; + let suffix_length = unsafe { suffix_lengths[i].assume_init() }; + + buffer.clear(); + + buffer.extend_from_slice(&self.decoder.last[..prefix_length]); + buffer.extend_from_slice( + &self.decoder.values[self.decoder.offset..self.decoder.offset + suffix_length], + ); + + target.push_value(&buffer); + + self.decoder.last.clear(); + std::mem::swap(&mut 
self.decoder.last, &mut buffer); + + self.decoder.offset += suffix_length; + } + } + + Ok(()) + } + + fn push_n_nulls( + &mut self, + target: &mut MutableBinaryViewArray<[u8]>, + n: usize, + ) -> ParquetResult<()> { + target.extend_constant(n, >::None); + Ok(()) + } + + fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { + self.decoder.skip_in_place(n) + } +} + impl utils::Decoder for BinViewDecoder { - type Translation<'a> = BinaryStateTranslation<'a>; - type Dict = BinaryDict; + type Translation<'a> = StateTranslation<'a>; + type Dict = (Vec, Vec>); type DecodedState = DecodedStateTuple; type Output = Box; @@ -156,36 +532,167 @@ impl utils::Decoder for BinViewDecoder { ) } - fn deserialize_dict(&self, page: DictPage) -> Self::Dict { - deserialize_plain(&page.buffer, page.num_values) + fn apply_dictionary( + &mut self, + (values, _): &mut Self::DecodedState, + dict: &Self::Dict, + ) -> ParquetResult<()> { + if values.completed_buffers().len() < dict.1.len() { + for buffer in &dict.1 { + values.push_buffer(buffer.clone()); + } + } + + assert!(values.completed_buffers().len() == dict.1.len()); + + Ok(()) + } + + fn deserialize_dict(&self, page: DictPage) -> ParquetResult { + let values = &page.buffer; + let num_values = page.num_values; + + // Each value is prepended by the length which is 4 bytes. + let num_bytes = values.len() - 4 * num_values; + + let mut views = Vec::with_capacity(num_values); + let mut buffer = Vec::with_capacity(num_bytes); + + let mut buffers = Vec::with_capacity(1); + + let mut offset = 0; + let mut max_length = 0; + views.extend(BinaryIter::new(values, num_values).map(|v| { + let length = v.len(); + max_length = usize::max(length, max_length); + if length <= View::MAX_INLINE_SIZE as usize { + View::new_inline(v) + } else { + if offset >= u32::MAX as usize { + let full_buffer = std::mem::take(&mut buffer); + let num_bytes = full_buffer.capacity() - full_buffer.len(); + buffers.push(Buffer::from(full_buffer)); + buffer.reserve(num_bytes); + offset = 0; + } + + buffer.extend_from_slice(v); + let view = View::new_from_bytes(v, buffers.len() as u32, offset as u32); + offset += v.len(); + view + } + })); + + buffers.push(Buffer::from(buffer)); + + if self.check_utf8.load(Ordering::Relaxed) { + // This is a small trick that allows us to check the Parquet buffer instead of the view + // buffer. Batching the UTF-8 verification is more performant. For this to be allowed, + // all the interleaved lengths need to be valid UTF-8. + // + // Every strings prepended by 4 bytes (L, 0, 0, 0), since we check here L < 128. L is + // only a valid first byte of a UTF-8 code-point and (L, 0, 0, 0) is valid UTF-8. + // Consequently, it is valid to just check the whole buffer. 
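+        // If any value is 128 bytes or longer, the length-prefix trick above does not apply, so
+        // fall back to validating the assembled views instead.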
+ if max_length < 128 { + simdutf8::basic::from_utf8(values) + .map_err(|_| ParquetError::oos("String data contained invalid UTF-8"))?; + } else { + arrow::array::validate_utf8_view(&views, &buffers) + .map_err(|_| ParquetError::oos("String data contained invalid UTF-8"))?; + } + } + + Ok((views, buffers)) } fn decode_plain_encoded<'a>( &mut self, (values, validity): &mut Self::DecodedState, - page_values: &mut as StateTranslation<'a, Self>>::PlainDecoder, + page_values: &mut as utils::StateTranslation<'a, Self>>::PlainDecoder, + is_optional: bool, page_validity: Option<&mut PageValidity<'a>>, limit: usize, ) -> ParquetResult<()> { let views_offset = values.views().len(); let buffer_offset = values.completed_buffers().len(); + struct Collector<'a, 'b> { + iter: &'b mut BinaryIter<'a>, + max_length: &'b mut usize, + } + + impl<'a, 'b> BatchableCollector<(), MutableBinaryViewArray<[u8]>> for Collector<'a, 'b> { + fn reserve(target: &mut MutableBinaryViewArray<[u8]>, n: usize) { + target.reserve(n); + } + + fn push_n( + &mut self, + target: &mut MutableBinaryViewArray<[u8]>, + n: usize, + ) -> ParquetResult<()> { + for x in self.iter.take(n) { + *self.max_length = usize::max(*self.max_length, x.len()); + target.push_value(x); + } + Ok(()) + } + + fn push_n_nulls( + &mut self, + target: &mut MutableBinaryViewArray<[u8]>, + n: usize, + ) -> ParquetResult<()> { + target.extend_constant(n, >::None); + Ok(()) + } + + fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { + if n > 0 { + _ = self.iter.nth(n - 1); + } + Ok(()) + } + } + + let mut max_length = 0; + let buffer = page_values.values; + let mut collector = Collector { + iter: page_values, + max_length: &mut max_length, + }; + match page_validity { None => { - for x in page_values.by_ref().take(limit) { - values.push_value_ignore_validity(x) + collector.push_n(values, limit)?; + + if is_optional { + validity.extend_constant(limit, true); } }, Some(page_validity) => { - extend_from_decoder(validity, page_validity, Some(limit), values, page_values)? + extend_from_decoder(validity, page_validity, Some(limit), values, collector)? }, } + let buffer = &buffer[..buffer.len() - page_values.values.len()]; + if self.check_utf8.load(Ordering::Relaxed) { - // @TODO: Better error message - values - .validate_utf8(buffer_offset, views_offset) - .map_err(|_| ParquetError::oos("Binary view contained invalid UTF-8"))? + // This is a small trick that allows us to check the Parquet buffer instead of the view + // buffer. Batching the UTF-8 verification is more performant. For this to be allowed, + // all the interleaved lengths need to be valid UTF-8. + // + // Every strings prepended by 4 bytes (L, 0, 0, 0), since we check here L < 128. L is + // only a valid first byte of a UTF-8 code-point and (L, 0, 0, 0) is valid UTF-8. + // Consequently, it is valid to just check the whole buffer. + if max_length < 128 { + simdutf8::basic::from_utf8(buffer) + .map_err(|_| ParquetError::oos("String data contained invalid UTF-8"))?; + } else { + values + .validate_utf8(buffer_offset, views_offset) + .map_err(|_| ParquetError::oos("String data contained invalid UTF-8"))? 
+ } } Ok(()) @@ -195,32 +702,138 @@ impl utils::Decoder for BinViewDecoder { &mut self, (values, validity): &mut Self::DecodedState, page_values: &mut hybrid_rle::HybridRleDecoder<'a>, + is_optional: bool, page_validity: Option<&mut PageValidity<'a>>, dict: &Self::Dict, limit: usize, ) -> ParquetResult<()> { - let validate_utf8 = self.check_utf8.load(Ordering::Relaxed); + struct DictionaryTranslator<'a>(&'a [View]); + + impl<'a> HybridRleGatherer for DictionaryTranslator<'a> { + type Target = MutableBinaryViewArray<[u8]>; + + fn target_reserve(&self, target: &mut Self::Target, n: usize) { + target.reserve(n); + } + + fn target_num_elements(&self, target: &Self::Target) -> usize { + target.len() + } + + fn hybridrle_to_target(&self, value: u32) -> ParquetResult { + self.0 + .get(value as usize) + .cloned() + .ok_or(ParquetError::oos("Dictionary index is out of range")) + } + + fn gather_one(&self, target: &mut Self::Target, value: View) -> ParquetResult<()> { + // SAFETY: + // - All the dictionary values are already buffered + // - We keep the `total_bytes_len` in-sync with the views + unsafe { + target.views_mut().push(value); + target.set_total_bytes_len(target.total_bytes_len() + value.length as usize); + } - if validate_utf8 && simdutf8::basic::from_utf8(dict.values()).is_err() { - return Err(ParquetError::oos( - "Binary view dictionary contained invalid UTF-8", - )); + Ok(()) + } + + fn gather_repeated( + &self, + target: &mut Self::Target, + value: View, + n: usize, + ) -> ParquetResult<()> { + // SAFETY: + // - All the dictionary values are already buffered + // - We keep the `total_bytes_len` in-sync with the views + unsafe { + let length = target.views_mut().len(); + target.views_mut().resize(length + n, value); + target + .set_total_bytes_len(target.total_bytes_len() + n * value.length as usize); + } + + Ok(()) + } + + fn gather_slice(&self, target: &mut Self::Target, source: &[u32]) -> ParquetResult<()> { + let Some(source_max) = source.iter().copied().max() else { + return Ok(()); + }; + + if source_max as usize >= self.0.len() { + return Err(ParquetError::oos("Dictionary index is out of range")); + } + + let mut view_length_sum = 0usize; + // Safety: We have checked before that source only has indexes that are smaller than the + // dictionary length. 
+ // + // Safety: + // - All the dictionary values are already buffered + // - We keep the `total_bytes_len` in-sync with the views + unsafe { + target.views_mut().extend(source.iter().map(|&src_idx| { + let v = *self.0.get_unchecked(src_idx as usize); + view_length_sum += v.length as usize; + v + })); + target.set_total_bytes_len(target.total_bytes_len() + view_length_sum); + } + + Ok(()) + } } - let views_dict = self - .views_dict - .get_or_insert_with(|| binary_views_dict(values, dict)); - let translator = DictionaryTranslator(views_dict); + let translator = DictionaryTranslator(&dict.0); match page_validity { None => { - page_values.translate_and_collect_n_into(values.views_mut(), limit, &translator)?; - if let Some(validity) = values.validity() { + page_values.gather_n_into(values, limit, &translator)?; + + if is_optional { validity.extend_constant(limit, true); } }, Some(page_validity) => { - let collector = TranslatedHybridRle::new(page_values, &translator); + struct Collector<'a, 'b> { + decoder: &'b mut hybrid_rle::HybridRleDecoder<'a>, + translator: DictionaryTranslator<'b>, + } + + impl<'a, 'b> BatchableCollector<(), MutableBinaryViewArray<[u8]>> for Collector<'a, 'b> { + fn reserve(target: &mut MutableBinaryViewArray<[u8]>, n: usize) { + target.reserve(n); + } + + fn push_n( + &mut self, + target: &mut MutableBinaryViewArray<[u8]>, + n: usize, + ) -> ParquetResult<()> { + self.decoder.gather_n_into(target, n, &self.translator)?; + Ok(()) + } + + fn push_n_nulls( + &mut self, + target: &mut MutableBinaryViewArray<[u8]>, + n: usize, + ) -> ParquetResult<()> { + target.extend_constant(n, >::None); + Ok(()) + } + + fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { + self.decoder.skip_in_place(n) + } + } + let collector = Collector { + decoder: page_values, + translator, + }; extend_from_decoder(validity, page_validity, Some(limit), values, collector)?; }, } @@ -230,7 +843,7 @@ impl utils::Decoder for BinViewDecoder { fn finalize( &self, - data_type: ArrowDataType, + dtype: ArrowDataType, _dict: Option, (values, validity): Self::DecodedState, ) -> ParquetResult> { @@ -239,13 +852,13 @@ impl utils::Decoder for BinViewDecoder { let validity = freeze_validity(validity); array = array.with_validity(validity); - match data_type.to_physical_type() { + match dtype.to_physical_type() { PhysicalType::BinaryView => Ok(array.boxed()), PhysicalType::Utf8View => { // SAFETY: we already checked utf8 unsafe { Ok(Utf8ViewArray::new_unchecked( - data_type, + dtype, array.views().clone(), array.data_buffers().clone(), array.validity().cloned(), @@ -263,28 +876,30 @@ impl utils::Decoder for BinViewDecoder { impl utils::DictDecodable for BinViewDecoder { fn finalize_dict_array( &self, - data_type: ArrowDataType, + dtype: ArrowDataType, dict: Self::Dict, keys: PrimitiveArray, ) -> ParquetResult> { - let value_data_type = match &data_type { + let value_dtype = match &dtype { ArrowDataType::Dictionary(_, values, _) => values.as_ref().clone(), - _ => data_type.clone(), + _ => dtype.clone(), }; - let mut view_dict = MutableBinaryViewArray::with_capacity(dict.len()); - for v in dict.iter() { - view_dict.push(v); + let mut view_dict = MutableBinaryViewArray::with_capacity(dict.0.len()); + for buffer in dict.1 { + view_dict.push_buffer(buffer); } + unsafe { view_dict.views_mut().extend(dict.0.iter()) }; + unsafe { view_dict.set_total_bytes_len(dict.0.iter().map(|v| v.length as usize).sum()) }; let view_dict = view_dict.freeze(); - let dict = match value_data_type.to_physical_type() { + let dict = 
match value_dtype.to_physical_type() { PhysicalType::Utf8View => view_dict.to_utf8view().unwrap().boxed(), PhysicalType::BinaryView => view_dict.boxed(), _ => unreachable!(), }; - Ok(DictionaryArray::try_new(data_type, keys, dict).unwrap()) + Ok(DictionaryArray::try_new(dtype, keys, dict).unwrap()) } } @@ -306,3 +921,53 @@ impl utils::NestedDecoder for BinViewDecoder { values.extend_constant(n, >::None); } } + +#[derive(Debug)] +pub struct BinaryIter<'a> { + values: &'a [u8], + + /// A maximum number of items that this [`BinaryIter`] may produce. + /// + /// This equal the length of the iterator i.f.f. the data encoded by the [`BinaryIter`] is not + /// nullable. + max_num_values: usize, +} + +impl<'a> BinaryIter<'a> { + pub fn new(values: &'a [u8], max_num_values: usize) -> Self { + Self { + values, + max_num_values, + } + } + + /// Return the length of the iterator when the data is not nullable. + pub fn len_when_not_nullable(&self) -> usize { + self.max_num_values + } +} + +impl<'a> Iterator for BinaryIter<'a> { + type Item = &'a [u8]; + + #[inline] + fn next(&mut self) -> Option { + if self.max_num_values == 0 { + assert!(self.values.is_empty()); + return None; + } + + let (length, remaining) = self.values.split_at(4); + let length: [u8; 4] = unsafe { length.try_into().unwrap_unchecked() }; + let length = u32::from_le_bytes(length) as usize; + let (result, remaining) = remaining.split_at(length); + self.max_num_values -= 1; + self.values = remaining; + Some(result) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.max_num_values)) + } +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs b/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs index 1f33da0678d6..af2e504d2646 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/boolean.rs @@ -89,20 +89,29 @@ impl<'a> utils::StateTranslation<'a, BooleanDecoder> for StateTranslation<'a> { &mut self, decoder: &mut BooleanDecoder, decoded: &mut ::DecodedState, + is_optional: bool, page_validity: &mut Option>, + _: Option<&'a ::Dict>, additional: usize, ) -> ParquetResult<()> { match self { Self::Plain(page_values) => decoder.decode_plain_encoded( decoded, page_values, + is_optional, page_validity.as_mut(), additional, )?, Self::Rle(page_values) => { let (values, validity) = decoded; match page_validity { - None => page_values.gather_n_into(values, additional, &BitmapGatherer)?, + None => { + page_values.gather_n_into(values, additional, &BitmapGatherer)?; + + if is_optional { + validity.extend_constant(additional, true); + } + }, Some(page_validity) => utils::extend_from_decoder( validity, page_validity, @@ -165,6 +174,10 @@ impl<'a, 'b> BatchableCollector for BitmapCollector<'a, 'b> target.extend_constant(n, false); Ok(()) } + + fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { + self.0.skip_in_place(n) + } } impl ExactSize for (MutableBitmap, MutableBitmap) { @@ -194,17 +207,26 @@ impl Decoder for BooleanDecoder { ) } - fn deserialize_dict(&self, _: DictPage) -> Self::Dict {} + fn deserialize_dict(&self, _: DictPage) -> ParquetResult { + Ok(()) + } fn decode_plain_encoded<'a>( &mut self, (values, validity): &mut Self::DecodedState, page_values: &mut as utils::StateTranslation<'a, Self>>::PlainDecoder, + is_optional: bool, page_validity: Option<&mut PageValidity<'a>>, limit: usize, ) -> ParquetResult<()> { match page_validity { - None => page_values.collect_n_into(values, limit), + None => { + 
page_values.collect_n_into(values, limit); + + if is_optional { + validity.extend_constant(limit, true); + } + }, Some(page_validity) => { extend_from_decoder(validity, page_validity, Some(limit), values, page_values)? }, @@ -217,6 +239,7 @@ impl Decoder for BooleanDecoder { &mut self, _decoded: &mut Self::DecodedState, _page_values: &mut HybridRleDecoder<'a>, + _is_optional: bool, _page_validity: Option<&mut PageValidity<'a>>, _dict: &Self::Dict, _limit: usize, @@ -226,12 +249,12 @@ impl Decoder for BooleanDecoder { fn finalize( &self, - data_type: ArrowDataType, + dtype: ArrowDataType, _dict: Option, (values, validity): Self::DecodedState, ) -> ParquetResult { let validity = freeze_validity(validity); - Ok(BooleanArray::new(data_type, values.into(), validity)) + Ok(BooleanArray::new(dtype, values.into(), validity)) } } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary.rs index 4a7b8f740063..de2bfe2e47f3 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary.rs @@ -47,7 +47,9 @@ impl<'a, K: DictionaryKey, D: utils::DictDecodable> StateTranslation<'a, Diction &mut self, decoder: &mut DictionaryDecoder, decoded: &mut as Decoder>::DecodedState, + is_optional: bool, page_validity: &mut Option>, + _: Option<&'a as Decoder>::Dict>, additional: usize, ) -> ParquetResult<()> { let (values, validity) = decoded; @@ -64,7 +66,13 @@ impl<'a, K: DictionaryKey, D: utils::DictDecodable> StateTranslation<'a, Diction }; match page_validity { - None => collector.push_n(&mut decoded.0, additional)?, + None => { + collector.push_n(&mut decoded.0, additional)?; + + if is_optional { + validity.extend_constant(additional, true); + } + }, Some(page_validity) => { extend_from_decoder(validity, page_validity, Some(additional), values, collector)? 
}, @@ -104,16 +112,16 @@ impl utils::Decoder for DictionaryDec ) } - fn deserialize_dict(&self, page: DictPage) -> Self::Dict { - let dict = self.decoder.deserialize_dict(page); + fn deserialize_dict(&self, page: DictPage) -> ParquetResult { + let dict = self.decoder.deserialize_dict(page)?; self.dict_size .store(dict.len(), std::sync::atomic::Ordering::Relaxed); - dict + Ok(dict) } fn finalize( &self, - data_type: ArrowDataType, + dtype: ArrowDataType, dict: Option, (values, validity): Self::DecodedState, ) -> ParquetResult> { @@ -121,13 +129,14 @@ impl utils::Decoder for DictionaryDec let dict = dict.unwrap(); let keys = PrimitiveArray::new(K::PRIMITIVE.into(), values.into(), validity); - self.decoder.finalize_dict_array(data_type, dict, keys) + self.decoder.finalize_dict_array(dtype, dict, keys) } fn decode_plain_encoded<'a>( &mut self, _decoded: &mut Self::DecodedState, _page_values: &mut as StateTranslation<'a, Self>>::PlainDecoder, + _is_optional: bool, _page_validity: Option<&mut PageValidity<'a>>, _limit: usize, ) -> ParquetResult<()> { @@ -138,6 +147,7 @@ impl utils::Decoder for DictionaryDec &mut self, _decoded: &mut Self::DecodedState, _page_values: &mut HybridRleDecoder<'a>, + _is_optional: bool, _page_validity: Option<&mut PageValidity<'a>>, _dict: &Self::Dict, _limit: usize, @@ -191,6 +201,10 @@ impl<'a, 'b, K: DictionaryKey> BatchableCollector<(), Vec> for DictArrayColle target.resize(target.len() + n, K::default()); Ok(()) } + + fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { + self.values.skip_in_place(n) + } } impl Translator for DictArrayTranslator { diff --git a/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs b/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs index 747243ce26ef..3825d528c8f5 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary.rs @@ -13,9 +13,10 @@ use crate::read::deserialize::utils::{self, BatchableCollector, GatheredHybridRl #[derive(Debug)] pub(crate) enum StateTranslation<'a> { Plain(&'a [u8], usize), - Dictionary(hybrid_rle::HybridRleDecoder<'a>, &'a Vec), + Dictionary(hybrid_rle::HybridRleDecoder<'a>), } +#[derive(Debug)] pub struct FixedSizeBinary { pub values: Vec, pub size: usize, @@ -42,9 +43,9 @@ impl<'a> utils::StateTranslation<'a, BinaryDecoder> for StateTranslation<'a> { } Ok(Self::Plain(values, decoder.size)) }, - (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict)) => { + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(_)) => { let values = dict_indices_decoder(page)?; - Ok(Self::Dictionary(values, dict)) + Ok(Self::Dictionary(values)) }, _ => Err(utils::not_implemented(page)), } @@ -53,7 +54,7 @@ impl<'a> utils::StateTranslation<'a, BinaryDecoder> for StateTranslation<'a> { fn len_when_not_nullable(&self) -> usize { match self { Self::Plain(v, size) => v.len() / size, - Self::Dictionary(v, _) => v.len(), + Self::Dictionary(v) => v.len(), } } @@ -64,7 +65,7 @@ impl<'a> utils::StateTranslation<'a, BinaryDecoder> for StateTranslation<'a> { match self { Self::Plain(v, size) => *v = &v[usize::min(v.len(), n * *size)..], - Self::Dictionary(v, _) => v.skip_in_place(n)?, + Self::Dictionary(v) => v.skip_in_place(n)?, } Ok(()) @@ -74,7 +75,9 @@ impl<'a> utils::StateTranslation<'a, BinaryDecoder> for StateTranslation<'a> { &mut self, decoder: &mut BinaryDecoder, decoded: &mut ::DecodedState, + is_optional: bool, page_validity: &mut Option>, + dict: Option<&'a 
::Dict>, additional: usize, ) -> ParquetResult<()> { use StateTranslation as T; @@ -82,14 +85,16 @@ impl<'a> utils::StateTranslation<'a, BinaryDecoder> for StateTranslation<'a> { T::Plain(page_values, _) => decoder.decode_plain_encoded( decoded, page_values, + is_optional, page_validity.as_mut(), additional, )?, - T::Dictionary(page_values, dict) => decoder.decode_dictionary_encoded( + T::Dictionary(page_values) => decoder.decode_dictionary_encoded( decoded, page_values, + is_optional, page_validity.as_mut(), - dict, + dict.unwrap(), additional, )?, } @@ -132,14 +137,15 @@ impl Decoder for BinaryDecoder { ) } - fn deserialize_dict(&self, page: DictPage) -> Self::Dict { - page.buffer.into_vec() + fn deserialize_dict(&self, page: DictPage) -> ParquetResult { + Ok(page.buffer.into_vec()) } fn decode_plain_encoded<'a>( &mut self, (values, validity): &mut Self::DecodedState, page_values: &mut as utils::StateTranslation<'a, Self>>::PlainDecoder, + is_optional: bool, page_validity: Option<&mut PageValidity<'a>>, limit: usize, ) -> ParquetResult<()> { @@ -164,6 +170,12 @@ impl Decoder for BinaryDecoder { target.resize(target.len() + n * self.size, 0); Ok(()) } + + fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { + let n = usize::min(n, self.slice.len() / self.size); + *self.slice = &self.slice[n * self.size..]; + Ok(()) + } } let mut collector = FixedSizeBinaryCollector { @@ -172,7 +184,13 @@ impl Decoder for BinaryDecoder { }; match page_validity { - None => collector.push_n(&mut values.values, self.size)?, + None => { + collector.push_n(&mut values.values, limit)?; + + if is_optional { + validity.extend_constant(limit, true); + } + }, Some(page_validity) => extend_from_decoder( validity, page_validity, @@ -189,6 +207,7 @@ impl Decoder for BinaryDecoder { &mut self, (values, validity): &mut Self::DecodedState, page_values: &mut hybrid_rle::HybridRleDecoder<'a>, + is_optional: bool, page_validity: Option<&mut PageValidity<'a>>, dict: &Self::Dict, limit: usize, @@ -266,6 +285,10 @@ impl Decoder for BinaryDecoder { match page_validity { None => { page_values.gather_n_into(&mut values.values, limit, &gatherer)?; + + if is_optional { + validity.extend_constant(limit, true); + } }, Some(page_validity) => { let collector = GatheredHybridRle::new(page_values, &gatherer, null_value); @@ -285,13 +308,13 @@ impl Decoder for BinaryDecoder { fn finalize( &self, - data_type: ArrowDataType, + dtype: ArrowDataType, _dict: Option, (values, validity): Self::DecodedState, ) -> ParquetResult { let validity = freeze_validity(validity); Ok(FixedSizeBinaryArray::new( - data_type, + dtype, values.values.into(), validity, )) @@ -301,13 +324,13 @@ impl Decoder for BinaryDecoder { impl utils::DictDecodable for BinaryDecoder { fn finalize_dict_array( &self, - data_type: ArrowDataType, + dtype: ArrowDataType, dict: Self::Dict, keys: PrimitiveArray, ) -> ParquetResult> { let dict = FixedSizeBinaryArray::new(ArrowDataType::FixedSizeBinary(self.size), dict.into(), None); - Ok(DictionaryArray::try_new(data_type, keys, Box::new(dict)).unwrap()) + Ok(DictionaryArray::try_new(dtype, keys, Box::new(dict)).unwrap()) } } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/mod.rs index 28c97f82ccf1..520f7f8596e1 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/mod.rs @@ -1,6 +1,5 @@ //! APIs to read from Parquet format. 
-mod binary; mod binview; mod boolean; mod dictionary; @@ -27,16 +26,14 @@ use crate::parquet::schema::types::PrimitiveType; /// Creates a new iterator of compressed pages. pub fn get_page_iterator( - column_metadata: &ColumnChunkMetaData, + column_metadata: &ColumnChunkMetadata, reader: MemReader, - pages_filter: Option, buffer: Vec, max_header_size: usize, ) -> PolarsResult { Ok(_get_page_iterator( column_metadata, reader, - pages_filter, buffer, max_header_size, )?) @@ -44,13 +41,13 @@ pub fn get_page_iterator( /// Creates a new [`ListArray`] or [`FixedSizeListArray`]. pub fn create_list( - data_type: ArrowDataType, + dtype: ArrowDataType, nested: &mut NestedState, values: Box, ) -> Box { let (mut offsets, validity) = nested.pop().unwrap(); let validity = validity.and_then(freeze_validity); - match data_type.to_logical_type() { + match dtype.to_logical_type() { ArrowDataType::List(_) => { offsets.push(values.len() as i64); @@ -61,7 +58,7 @@ pub fn create_list( .expect("i64 offsets do not fit in i32 offsets"); Box::new(ListArray::::new( - data_type, + dtype, offsets.into(), values, validity, @@ -71,14 +68,14 @@ pub fn create_list( offsets.push(values.len() as i64); Box::new(ListArray::::new( - data_type, + dtype, offsets.try_into().expect("List too large"), values, validity, )) }, ArrowDataType::FixedSizeList(_, _) => { - Box::new(FixedSizeListArray::new(data_type, values, validity)) + Box::new(FixedSizeListArray::new(dtype, values, validity)) }, _ => unreachable!(), } @@ -86,12 +83,12 @@ pub fn create_list( /// Creates a new [`MapArray`]. pub fn create_map( - data_type: ArrowDataType, + dtype: ArrowDataType, nested: &mut NestedState, values: Box, ) -> Box { let (mut offsets, validity) = nested.pop().unwrap(); - match data_type.to_logical_type() { + match dtype.to_logical_type() { ArrowDataType::Map(_, _) => { offsets.push(values.len() as i64); let offsets = offsets.iter().map(|x| *x as i32).collect::>(); @@ -101,7 +98,7 @@ pub fn create_map( .expect("i64 offsets do not fit in i32 offsets"); Box::new(MapArray::new( - data_type, + dtype, offsets.into(), values, validity.and_then(freeze_validity), @@ -111,9 +108,9 @@ pub fn create_map( } } -fn is_primitive(data_type: &ArrowDataType) -> bool { +fn is_primitive(dtype: &ArrowDataType) -> bool { matches!( - data_type.to_physical_type(), + dtype.to_physical_type(), arrow::datatypes::PhysicalType::Primitive(_) | arrow::datatypes::PhysicalType::Null | arrow::datatypes::PhysicalType::Boolean @@ -135,11 +132,11 @@ fn columns_to_iter_recursive( init: Vec, filter: Option, ) -> PolarsResult<(NestedState, Box)> { - if init.is_empty() && is_primitive(&field.data_type) { + if init.is_empty() && is_primitive(&field.dtype) { let array = page_iter_to_array( columns.pop().unwrap(), types.pop().unwrap(), - field.data_type, + field.dtype, filter, )?; @@ -150,34 +147,34 @@ fn columns_to_iter_recursive( } /// Returns the number of (parquet) columns that a [`ArrowDataType`] contains. 
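+/// For example, a `Struct` containing an `Int32` field and a `List<Utf8>` field maps to two
+/// parquet columns.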
-pub fn n_columns(data_type: &ArrowDataType) -> usize { +pub fn n_columns(dtype: &ArrowDataType) -> usize { use arrow::datatypes::PhysicalType::*; - match data_type.to_physical_type() { + match dtype.to_physical_type() { Null | Boolean | Primitive(_) | Binary | FixedSizeBinary | LargeBinary | Utf8 | Dictionary(_) | LargeUtf8 | BinaryView | Utf8View => 1, List | FixedSizeList | LargeList => { - let a = data_type.to_logical_type(); + let a = dtype.to_logical_type(); if let ArrowDataType::List(inner) = a { - n_columns(&inner.data_type) + n_columns(&inner.dtype) } else if let ArrowDataType::LargeList(inner) = a { - n_columns(&inner.data_type) + n_columns(&inner.dtype) } else if let ArrowDataType::FixedSizeList(inner, _) = a { - n_columns(&inner.data_type) + n_columns(&inner.dtype) } else { unreachable!() } }, Map => { - let a = data_type.to_logical_type(); + let a = dtype.to_logical_type(); if let ArrowDataType::Map(inner, _) = a { - n_columns(&inner.data_type) + n_columns(&inner.dtype) } else { unreachable!() } }, Struct => { - if let ArrowDataType::Struct(fields) = data_type.to_logical_type() { - fields.iter().map(|inner| n_columns(&inner.data_type)).sum() + if let ArrowDataType::Struct(fields) = dtype.to_logical_type() { + fields.iter().map(|inner| n_columns(&inner.dtype)).sum() } else { unreachable!() } @@ -191,14 +188,13 @@ pub fn n_columns(data_type: &ArrowDataType) -> usize { /// For a non-nested datatypes such as [`ArrowDataType::Int32`], this function requires a single element in `columns` and `types`. /// For nested types, `columns` must be composed by all parquet columns with associated types `types`. /// -/// The arrays are guaranteed to be at most of size `chunk_size` and data type `field.data_type`. -pub fn column_iter_to_arrays<'a>( +/// The arrays are guaranteed to be at most of size `chunk_size` and data type `field.dtype`. +pub fn column_iter_to_arrays( columns: Vec, types: Vec<&PrimitiveType>, field: Field, filter: Option, -) -> PolarsResult> { +) -> PolarsResult> { let (_, array) = columns_to_iter_recursive(columns, types, field, vec![], filter)?; - - Ok(Box::new(std::iter::once(Ok(array)))) + Ok(array) } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/nested.rs b/crates/polars-parquet/src/arrow/read/deserialize/nested.rs index b82abec09996..114eeef67341 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/nested.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/nested.rs @@ -18,14 +18,14 @@ pub fn columns_to_iter_recursive( use arrow::datatypes::PhysicalType::*; use arrow::datatypes::PrimitiveType::*; - Ok(match field.data_type().to_physical_type() { + Ok(match field.dtype().to_physical_type() { Null => { // physical type is i32 init.push(InitNested::Primitive(field.is_nullable)); types.pop(); PageNestedDecoder::new( columns.pop().unwrap(), - field.data_type().clone(), + field.dtype().clone(), null::NullDecoder, init, )? @@ -49,8 +49,8 @@ pub fn columns_to_iter_recursive( types.pop(); PageNestedDecoder::new( columns.pop().unwrap(), - field.data_type().clone(), - primitive::PrimitiveDecoder::::cast_as(), + field.dtype().clone(), + primitive::IntDecoder::::cast_as(), init, )? .collect_n(filter) @@ -61,8 +61,8 @@ pub fn columns_to_iter_recursive( types.pop(); PageNestedDecoder::new( columns.pop().unwrap(), - field.data_type().clone(), - primitive::PrimitiveDecoder::::cast_as(), + field.dtype().clone(), + primitive::IntDecoder::::cast_as(), init, )? 
.collect_n(filter) @@ -73,8 +73,8 @@ pub fn columns_to_iter_recursive( types.pop(); PageNestedDecoder::new( columns.pop().unwrap(), - field.data_type().clone(), - primitive::PrimitiveDecoder::::unit(), + field.dtype().clone(), + primitive::IntDecoder::::unit(), init, )? .collect_n(filter) @@ -85,8 +85,8 @@ pub fn columns_to_iter_recursive( types.pop(); PageNestedDecoder::new( columns.pop().unwrap(), - field.data_type().clone(), - primitive::PrimitiveDecoder::::unit(), + field.dtype().clone(), + primitive::IntDecoder::::unit(), init, )? .collect_n(filter) @@ -97,8 +97,8 @@ pub fn columns_to_iter_recursive( types.pop(); PageNestedDecoder::new( columns.pop().unwrap(), - field.data_type().clone(), - primitive::PrimitiveDecoder::::cast_as(), + field.dtype().clone(), + primitive::IntDecoder::::cast_as(), init, )? .collect_n(filter) @@ -109,8 +109,8 @@ pub fn columns_to_iter_recursive( types.pop(); PageNestedDecoder::new( columns.pop().unwrap(), - field.data_type().clone(), - primitive::PrimitiveDecoder::::cast_as(), + field.dtype().clone(), + primitive::IntDecoder::::cast_as(), init, )? .collect_n(filter) @@ -122,8 +122,8 @@ pub fn columns_to_iter_recursive( match type_.physical_type { PhysicalType::Int32 => PageNestedDecoder::new( columns.pop().unwrap(), - field.data_type().clone(), - primitive::PrimitiveDecoder::::cast_as(), + field.dtype().clone(), + primitive::IntDecoder::::cast_as(), init, )? .collect_n(filter) @@ -131,8 +131,8 @@ pub fn columns_to_iter_recursive( // some implementations of parquet write arrow's u32 into i64. PhysicalType::Int64 => PageNestedDecoder::new( columns.pop().unwrap(), - field.data_type().clone(), - primitive::PrimitiveDecoder::::cast_as(), + field.dtype().clone(), + primitive::IntDecoder::::cast_as(), init, )? .collect_n(filter) @@ -149,8 +149,8 @@ pub fn columns_to_iter_recursive( types.pop(); PageNestedDecoder::new( columns.pop().unwrap(), - field.data_type().clone(), - primitive::PrimitiveDecoder::::cast_as(), + field.dtype().clone(), + primitive::IntDecoder::::cast_as(), init, )? .collect_n(filter) @@ -161,8 +161,8 @@ pub fn columns_to_iter_recursive( types.pop(); PageNestedDecoder::new( columns.pop().unwrap(), - field.data_type().clone(), - primitive::PrimitiveDecoder::::unit(), + field.dtype().clone(), + primitive::FloatDecoder::::unit(), init, )? .collect_n(filter) @@ -173,8 +173,8 @@ pub fn columns_to_iter_recursive( types.pop(); PageNestedDecoder::new( columns.pop().unwrap(), - field.data_type().clone(), - primitive::PrimitiveDecoder::::unit(), + field.dtype().clone(), + primitive::FloatDecoder::::unit(), init, )? .collect_n(filter) @@ -185,32 +185,23 @@ pub fn columns_to_iter_recursive( types.pop(); PageNestedDecoder::new( columns.pop().unwrap(), - field.data_type().clone(), + field.dtype().clone(), binview::BinViewDecoder::default(), init, )? .collect_n(filter)? }, - LargeBinary | LargeUtf8 => { - init.push(InitNested::Primitive(field.is_nullable)); - types.pop(); - PageNestedDecoder::new( - columns.pop().unwrap(), - field.data_type().clone(), - binary::BinaryDecoder::::default(), - init, - )? - .collect_n(filter)? - }, - _ => match field.data_type().to_logical_type() { + // These are all converted to View variants before. 
+ LargeBinary | LargeUtf8 | Binary | Utf8 => unreachable!(), + _ => match field.dtype().to_logical_type() { ArrowDataType::Dictionary(key_type, _, _) => { init.push(InitNested::Primitive(field.is_nullable)); let type_ = types.pop().unwrap(); let iter = columns.pop().unwrap(); - let data_type = field.data_type().clone(); + let dtype = field.dtype().clone(); match_integer_type!(key_type, |$K| { - dict_read::<$K>(iter, init, type_, data_type, filter).map(|(s, arr)| (s, Box::new(arr) as Box<_>)) + dict_read::<$K>(iter, init, type_, dtype, filter).map(|(s, arr)| (s, Box::new(arr) as Box<_>)) })? }, ArrowDataType::List(inner) | ArrowDataType::LargeList(inner) => { @@ -222,7 +213,7 @@ pub fn columns_to_iter_recursive( init, filter, )?; - let array = create_list(field.data_type().clone(), &mut nested, array); + let array = create_list(field.dtype().clone(), &mut nested, array); (nested, array) }, ArrowDataType::FixedSizeList(inner, width) => { @@ -234,7 +225,7 @@ pub fn columns_to_iter_recursive( init, filter, )?; - let array = create_list(field.data_type().clone(), &mut nested, array); + let array = create_list(field.dtype().clone(), &mut nested, array); (nested, array) }, ArrowDataType::Decimal(_, _) => { @@ -243,16 +234,16 @@ pub fn columns_to_iter_recursive( match type_.physical_type { PhysicalType::Int32 => PageNestedDecoder::new( columns.pop().unwrap(), - field.data_type.clone(), - primitive::PrimitiveDecoder::::cast_into(), + field.dtype.clone(), + primitive::IntDecoder::::cast_into(), init, )? .collect_n(filter) .map(|(s, a)| (s, Box::new(a) as Box<_>))?, PhysicalType::Int64 => PageNestedDecoder::new( columns.pop().unwrap(), - field.data_type.clone(), - primitive::PrimitiveDecoder::::cast_into(), + field.dtype.clone(), + primitive::IntDecoder::::cast_into(), init, )? .collect_n(filter) @@ -280,7 +271,7 @@ pub fn columns_to_iter_recursive( let validity = array.validity().cloned(); let array: Box = Box::new(PrimitiveArray::::try_new( - field.data_type.clone(), + field.dtype.clone(), values.into(), validity, )?); @@ -301,16 +292,16 @@ pub fn columns_to_iter_recursive( match type_.physical_type { PhysicalType::Int32 => PageNestedDecoder::new( columns.pop().unwrap(), - field.data_type.clone(), - primitive::PrimitiveDecoder::closure(|x: i32| i256(I256::new(x as i128))), + field.dtype.clone(), + primitive::IntDecoder::closure(|x: i32| i256(I256::new(x as i128))), init, )? .collect_n(filter) .map(|(s, a)| (s, Box::new(a) as Box<_>))?, PhysicalType::Int64 => PageNestedDecoder::new( columns.pop().unwrap(), - field.data_type.clone(), - primitive::PrimitiveDecoder::closure(|x: i64| i256(I256::new(x as i128))), + field.dtype.clone(), + primitive::IntDecoder::closure(|x: i64| i256(I256::new(x as i128))), init, )? 
.collect_n(filter) @@ -333,7 +324,7 @@ pub fn columns_to_iter_recursive( let validity = array.validity().cloned(); let array: Box = Box::new(PrimitiveArray::::try_new( - field.data_type.clone(), + field.dtype.clone(), values.into(), validity, )?); @@ -359,7 +350,7 @@ pub fn columns_to_iter_recursive( let validity = array.validity().cloned(); let array: Box = Box::new(PrimitiveArray::::try_new( - field.data_type.clone(), + field.dtype.clone(), values.into(), validity, )?); @@ -396,7 +387,7 @@ pub fn columns_to_iter_recursive( types: &mut Vec<&PrimitiveType>, struct_field: &Field| { init.push(InitNested::Struct(field.is_nullable)); - let n = n_columns(&struct_field.data_type); + let n = n_columns(&struct_field.dtype); let columns = columns.split_off(columns.len() - n); let types = types.split_off(types.len() - n); @@ -454,7 +445,7 @@ pub fn columns_to_iter_recursive( init, filter, )?; - let array = create_map(field.data_type().clone(), &mut nested, array); + let array = create_map(field.dtype().clone(), &mut nested, array); (nested, array) }, other => { @@ -470,106 +461,85 @@ fn dict_read( iter: BasicDecompressor, init: Vec, _type_: &PrimitiveType, - data_type: ArrowDataType, + dtype: ArrowDataType, filter: Option, ) -> PolarsResult<(NestedState, DictionaryArray)> { use ArrowDataType::*; - let values_data_type = if let Dictionary(_, v, _) = &data_type { + let values_dtype = if let Dictionary(_, v, _) = &dtype { v.as_ref() } else { panic!() }; - Ok(match values_data_type.to_logical_type() { - UInt8 => { - PageNestedDecoder::new( - iter, - data_type, - dictionary::DictionaryDecoder::new( - primitive::PrimitiveDecoder::::cast_as(), - ), - init, - )? - .collect_n(filter)? - }, + Ok(match values_dtype.to_logical_type() { + UInt8 => PageNestedDecoder::new( + iter, + dtype, + dictionary::DictionaryDecoder::new(primitive::IntDecoder::::cast_as()), + init, + )? + .collect_n(filter)?, UInt16 => PageNestedDecoder::new( iter, - data_type, - dictionary::DictionaryDecoder::new( - primitive::PrimitiveDecoder::::cast_as(), - ), + dtype, + dictionary::DictionaryDecoder::new(primitive::IntDecoder::::cast_as()), init, )? .collect_n(filter)?, UInt32 => PageNestedDecoder::new( iter, - data_type, - dictionary::DictionaryDecoder::new( - primitive::PrimitiveDecoder::::cast_as(), - ), + dtype, + dictionary::DictionaryDecoder::new(primitive::IntDecoder::::cast_as()), + init, + )? + .collect_n(filter)?, + Int8 => PageNestedDecoder::new( + iter, + dtype, + dictionary::DictionaryDecoder::new(primitive::IntDecoder::::cast_as()), init, )? .collect_n(filter)?, - Int8 => { - PageNestedDecoder::new( - iter, - data_type, - dictionary::DictionaryDecoder::new( - primitive::PrimitiveDecoder::::cast_as(), - ), - init, - )? - .collect_n(filter)? - }, Int16 => PageNestedDecoder::new( iter, - data_type, - dictionary::DictionaryDecoder::new( - primitive::PrimitiveDecoder::::cast_as(), - ), + dtype, + dictionary::DictionaryDecoder::new(primitive::IntDecoder::::cast_as()), init, )? .collect_n(filter)?, Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => PageNestedDecoder::new( iter, - data_type, - dictionary::DictionaryDecoder::new(primitive::PrimitiveDecoder::::unit()), + dtype, + dictionary::DictionaryDecoder::new(primitive::IntDecoder::::unit()), init, )? 
.collect_n(filter)?, Int64 | Date64 | Time64(_) | Duration(_) => PageNestedDecoder::new( iter, - data_type, - dictionary::DictionaryDecoder::new( - primitive::PrimitiveDecoder::::cast_as(), - ), + dtype, + dictionary::DictionaryDecoder::new(primitive::IntDecoder::::cast_as()), init, )? .collect_n(filter)?, Float32 => PageNestedDecoder::new( iter, - data_type, - dictionary::DictionaryDecoder::new(primitive::PrimitiveDecoder::::unit()), + dtype, + dictionary::DictionaryDecoder::new(primitive::FloatDecoder::::unit()), init, )? .collect_n(filter)?, Float64 => PageNestedDecoder::new( iter, - data_type, - dictionary::DictionaryDecoder::new(primitive::PrimitiveDecoder::::unit()), - init, - )? - .collect_n(filter)?, - LargeUtf8 | LargeBinary => PageNestedDecoder::new( - iter, - data_type, - dictionary::DictionaryDecoder::new(binary::BinaryDecoder::::default()), + dtype, + dictionary::DictionaryDecoder::new(primitive::FloatDecoder::::unit()), init, )? .collect_n(filter)?, + // These are all converted to View variants before. + LargeUtf8 | LargeBinary | Utf8 | Binary => unreachable!(), Utf8View | BinaryView => PageNestedDecoder::new( iter, - data_type, + dtype, dictionary::DictionaryDecoder::new(binview::BinViewDecoder::default()), init, )? @@ -578,7 +548,7 @@ fn dict_read( let size = *size; PageNestedDecoder::new( iter, - data_type, + dtype, dictionary::DictionaryDecoder::new(fixed_size_binary::BinaryDecoder { size }), init, )? @@ -592,7 +562,7 @@ fn dict_read( iter, physical_type, logical_type, - data_type, + dtype, chunk_size, time_unit, ); diff --git a/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs b/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs index 217d9f694510..ad542cf05753 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs @@ -4,14 +4,12 @@ use polars_error::PolarsResult; use super::utils::{self, BatchableCollector}; use super::{BasicDecompressor, Filter}; -use crate::parquet::encoding::hybrid_rle::gatherer::{ - HybridRleGatherer, ZeroCount, ZeroCountGatherer, -}; +use crate::parquet::encoding::hybrid_rle::gatherer::HybridRleGatherer; use crate::parquet::encoding::hybrid_rle::HybridRleDecoder; use crate::parquet::error::ParquetResult; use crate::parquet::page::{split_buffer, DataPage}; use crate::parquet::read::levels::get_bit_width; -use crate::read::deserialize::utils::BatchedCollector; +use crate::read::deserialize::utils::{hybrid_rle_count_zeros, BatchedCollector}; #[derive(Debug)] pub struct Nested { @@ -140,7 +138,7 @@ impl Nested { fn invalid_num_values(&self) -> usize { match &self.content { - NestedContent::Primitive => 0, + NestedContent::Primitive => 1, NestedContent::List { .. } => 0, NestedContent::FixedSizeList { width } => *width, NestedContent::Struct => 1, @@ -204,6 +202,10 @@ impl<'a, 'b, 'c, D: utils::NestedDecoder> BatchableCollector<(), D::DecodedState self.decoder.push_n_nulls(self.state, target, n); Ok(()) } + + fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { + self.state.skip_in_place(n) + } } /// The initial info of nested data types. @@ -290,6 +292,67 @@ impl NestedState { } } +/// Calculate the number of leaf values that are covered by the first `limit` definition level +/// values. 
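+/// A definition level contributes a leaf value iff it equals the leaf (maximum) definition
+/// level, i.e. iff the value is physically stored in the page.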
+fn limit_to_num_values( + def_iter: &HybridRleDecoder<'_>, + def_levels: &[u16], + limit: usize, +) -> ParquetResult { + struct NumValuesGatherer { + leaf_def_level: u16, + } + struct NumValuesState { + num_values: usize, + length: usize, + } + + impl HybridRleGatherer for NumValuesGatherer { + type Target = NumValuesState; + + fn target_reserve(&self, _target: &mut Self::Target, _n: usize) {} + + fn target_num_elements(&self, target: &Self::Target) -> usize { + target.length + } + + fn hybridrle_to_target(&self, value: u32) -> ParquetResult { + Ok(value) + } + + fn gather_one(&self, target: &mut Self::Target, value: u32) -> ParquetResult<()> { + target.num_values += usize::from(value == self.leaf_def_level as u32); + target.length += 1; + Ok(()) + } + + fn gather_repeated( + &self, + target: &mut Self::Target, + value: u32, + n: usize, + ) -> ParquetResult<()> { + target.num_values += n * usize::from(value == self.leaf_def_level as u32); + target.length += n; + Ok(()) + } + } + + let mut state = NumValuesState { + num_values: 0, + length: 0, + }; + def_iter.clone().gather_n_into( + &mut state, + limit, + &NumValuesGatherer { + leaf_def_level: *def_levels.last().unwrap(), + }, + )?; + + Ok(state.num_values) +} + fn idx_to_limit(rep_iter: &HybridRleDecoder<'_>, idx: usize) -> ParquetResult { struct RowIdxOffsetGatherer; struct RowIdxOffsetState { @@ -384,18 +447,22 @@ fn extend_offsets2<'a, D: utils::NestedDecoder>( >, nested: &mut [Nested], filter: Option, - // Amortized allocations + def_levels: &[u16], rep_levels: &[u16], ) -> PolarsResult<()> { + debug_assert_eq!(def_iter.len(), rep_iter.len()); + match filter { None => { + let limit = def_iter.len(); + extend_offsets_limited( &mut def_iter, &mut rep_iter, batched_collector, nested, - usize::MAX, + limit, def_levels, rep_levels, )?; @@ -412,6 +479,9 @@ fn extend_offsets2<'a, D: utils::NestedDecoder>( if start > 0 { let start_cell = idx_to_limit(&rep_iter, start)?; + let num_skipped_values = limit_to_num_values(&def_iter, def_levels, start_cell)?; + batched_collector.skip_in_place(num_skipped_values)?; + rep_iter.skip_in_place(start_cell)?; def_iter.skip_in_place(start_cell)?; } @@ -432,6 +502,8 @@ fn extend_offsets2<'a, D: utils::NestedDecoder>( // @NOTE: This is kind of unused let last_skip = def_iter.len(); + let num_skipped_values = limit_to_num_values(&def_iter, def_levels, last_skip)?; + batched_collector.skip_in_place(num_skipped_values)?; rep_iter.skip_in_place(last_skip)?; def_iter.skip_in_place(last_skip)?; @@ -443,6 +515,8 @@ fn extend_offsets2<'a, D: utils::NestedDecoder>( let num_zeros = iter.take_leading_zeros(); if num_zeros > 0 { let offset = idx_to_limit(&rep_iter, num_zeros)?; + let num_skipped_values = limit_to_num_values(&def_iter, def_levels, offset)?; + batched_collector.skip_in_place(num_skipped_values)?; rep_iter.skip_in_place(offset)?; def_iter.skip_in_place(offset)?; } @@ -461,6 +535,7 @@ fn extend_offsets2<'a, D: utils::NestedDecoder>( )?; } } + Ok(()) }, } @@ -597,23 +672,16 @@ fn extend_offsets_limited<'a, D: utils::NestedDecoder>( } } - if embed_depth == max_depth - 1 { - for _ in 0..num_elements { - batched_collector.push_invalid(); - } - - break; - } - let embed_num_values = embed_nest.invalid_num_values(); + num_elements *= embed_num_values; if embed_num_values == 0 { break; } - - num_elements *= embed_num_values; } + batched_collector.push_n_invalids(num_elements); + break; } @@ -645,7 +713,7 @@ fn extend_offsets_limited<'a, D: utils::NestedDecoder>( pub struct PageNestedDecoder { pub iter: 
BasicDecompressor, - pub data_type: ArrowDataType, + pub dtype: ArrowDataType, pub dict: Option, pub decoder: D, pub init: Vec, @@ -669,16 +737,16 @@ fn level_iters(page: &DataPage) -> ParquetResult<(HybridRleDecoder, HybridRleDec impl PageNestedDecoder { pub fn new( mut iter: BasicDecompressor, - data_type: ArrowDataType, + dtype: ArrowDataType, decoder: D, init: Vec, ) -> ParquetResult { let dict_page = iter.read_dict_page()?; - let dict = dict_page.map(|d| decoder.deserialize_dict(d)); + let dict = dict_page.map(|d| decoder.deserialize_dict(d)).transpose()?; Ok(Self { iter, - data_type, + dtype, dict, decoder, init, @@ -691,6 +759,10 @@ impl PageNestedDecoder { // @TODO: Self capacity let mut nested_state = init_nested(&self.init, 0); + if let Some(dict) = self.dict.as_ref() { + self.decoder.apply_dictionary(&mut target, dict)?; + } + // Amortize the allocations. let (def_levels, rep_levels) = nested_state.levels(); @@ -701,6 +773,7 @@ impl PageNestedDecoder { break; }; let page = page?; + let page = page.decompress(&mut self.iter)?; let mut state = utils::State::new_nested(&self.decoder, &page, self.dict.as_ref())?; @@ -732,49 +805,147 @@ impl PageNestedDecoder { } }, Some(mut filter) => { + enum PageStartAction { + Skip, + Collect, + } + + // We may have an action (skip / collect) for one row value left over from the + // previous page. Every page may state what the next page needs to do until the + // first of its own row values (rep_lvl = 0). + let mut last_row_value_action = PageStartAction::Skip; let mut num_rows_remaining = filter.num_rows(); - loop { + while num_rows_remaining > 0 + || matches!(last_row_value_action, PageStartAction::Collect) + { let Some(page) = self.iter.next() else { break; }; let page = page?; + // We cannot lazily decompress because we don't have the number of row values + // at this point. We need repetition levels for that. *sign*. In general, lazy + // decompression is quite difficult with nested values. + // + // @TODO + // Lazy decompression is quite doable in the V2 specification since that does + // not compress the repetition and definition levels. However, not a lot of + // people use the V2 specification. So let us ignore that for now. + let page = page.decompress(&mut self.iter)?; - let mut state = - utils::State::new_nested(&self.decoder, &page, self.dict.as_ref())?; - let (def_iter, rep_iter) = level_iters(&page)?; + let (mut def_iter, mut rep_iter) = level_iters(&page)?; - let mut count = ZeroCount::default(); - rep_iter - .clone() - .gather_into(&mut count, &ZeroCountGatherer)?; + let mut state; + let mut batched_collector; + + let start_length = nested_state.len(); + + // rep lvl == 0 ==> row value + let num_row_values = hybrid_rle_count_zeros(&rep_iter)?; - let is_fully_read = count.num_zero > num_rows_remaining; let state_filter; - (state_filter, filter) = Filter::split_at(&filter, count.num_zero); - let state_filter = if count.num_zero > 0 { - Some(state_filter) - } else { - None - }; + (state_filter, filter) = Filter::split_at(&filter, num_row_values); + + match last_row_value_action { + PageStartAction::Skip => { + // Fast path: skip the whole page. + // No new row values or we don't care about any of the row values. + if num_row_values == 0 && state_filter.num_rows() == 0 { + self.iter.reuse_page_buffer(page); + continue; + } - let start_length = nested_state.len(); + let limit = idx_to_limit(&rep_iter, 0)?; + + // We just saw that we had at least one row value. 
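+                        // The leading `limit` levels (if any) still belong to the row carried
+                        // over from the previous page; its remaining leaf values are skipped
+                        // just below.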
+ debug_assert!(limit < rep_iter.len()); + + state = + utils::State::new_nested(&self.decoder, &page, self.dict.as_ref())?; + batched_collector = BatchedCollector::new( + BatchedNestedDecoder { + state: &mut state, + decoder: &mut self.decoder, + }, + &mut target, + ); + + let num_leaf_values = + limit_to_num_values(&def_iter, &def_levels, limit)?; + batched_collector.skip_in_place(num_leaf_values)?; + rep_iter.skip_in_place(limit)?; + def_iter.skip_in_place(limit)?; + }, + PageStartAction::Collect => { + let limit = if num_row_values == 0 { + rep_iter.len() + } else { + idx_to_limit(&rep_iter, 0)? + }; + + // Fast path: we are not interested in any of the row values in this + // page. + if limit == 0 && state_filter.num_rows() == 0 { + self.iter.reuse_page_buffer(page); + last_row_value_action = PageStartAction::Skip; + continue; + } - // @TODO: move this to outside the loop. - let mut batched_collector = BatchedCollector::new( - BatchedNestedDecoder { - state: &mut state, - decoder: &mut self.decoder, + state = + utils::State::new_nested(&self.decoder, &page, self.dict.as_ref())?; + batched_collector = BatchedCollector::new( + BatchedNestedDecoder { + state: &mut state, + decoder: &mut self.decoder, + }, + &mut target, + ); + + extend_offsets_limited( + &mut def_iter, + &mut rep_iter, + &mut batched_collector, + &mut nested_state.nested, + limit, + &def_levels, + &rep_levels, + )?; + + // No new row values. Keep collecting. + if rep_iter.len() == 0 { + batched_collector.finalize()?; + + let num_done = nested_state.len() - start_length; + debug_assert!(num_done <= num_rows_remaining); + debug_assert!(num_done <= num_row_values); + num_rows_remaining -= num_done; + + drop(state); + self.iter.reuse_page_buffer(page); + + continue; + } }, - &mut target, - ); + } + + // Two cases: + // 1. First page: Must always start with a row value. + // 2. Other pages: If they did not have a row value, they would have been + // handled by the last_row_value_action. 
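The skip path above turns a count of definition/repetition entries into a count of physical leaf values: only entries whose definition level equals the leaf definition level actually carry a value; everything else is a null or an empty/missing list further up the nesting. A minimal model of what limit_to_num_values computes, using plain slices instead of the RLE decoder (the function name here is illustrative only):

// Of the first `limit` definition levels, count those that reach the leaf
// definition level; only these correspond to physical values in the page.
fn num_values_in_prefix(def_levels: &[u16], leaf_def_level: u16, limit: usize) -> usize {
    def_levels
        .iter()
        .take(limit)
        .filter(|&&lvl| lvl == leaf_def_level)
        .count()
}

fn main() {
    // With a leaf definition level of 3, only two of the first four entries
    // carry a physical value.
    assert_eq!(num_values_in_prefix(&[3, 1, 3, 0, 3], 3, 4), 2);
}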
+ debug_assert!(num_row_values > 0); + + last_row_value_action = if state_filter.do_include_at(num_row_values - 1) { + PageStartAction::Collect + } else { + PageStartAction::Skip + }; extend_offsets2( def_iter, rep_iter, &mut batched_collector, &mut nested_state.nested, - state_filter, + Some(state_filter), &def_levels, &rep_levels, )?; @@ -783,15 +954,11 @@ impl PageNestedDecoder { let num_done = nested_state.len() - start_length; debug_assert!(num_done <= num_rows_remaining); - debug_assert!(num_done <= count.num_zero); + debug_assert!(num_done <= num_row_values); num_rows_remaining -= num_done; drop(state); self.iter.reuse_page_buffer(page); - - if is_fully_read { - break; - } } }, } @@ -803,7 +970,7 @@ impl PageNestedDecoder { )); _ = nested_state.pop().unwrap(); - let array = self.decoder.finalize(self.data_type, self.dict, target)?; + let array = self.decoder.finalize(self.dtype, self.dict, target)?; Ok((nested_state, array)) } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/null.rs b/crates/polars-parquet/src/arrow/read/deserialize/null.rs index 74defc1d3b74..e12757fe2e20 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/null.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/null.rs @@ -12,6 +12,7 @@ use crate::parquet::error::ParquetResult; use crate::parquet::page::{DataPage, DictPage}; pub(crate) struct NullDecoder; +#[derive(Debug)] pub(crate) struct NullArrayLength { length: usize, } @@ -46,7 +47,9 @@ impl<'a> utils::StateTranslation<'a, NullDecoder> for () { &mut self, _decoder: &mut NullDecoder, decoded: &mut ::DecodedState, + _is_optional: bool, _page_validity: &mut Option>, + _: Option<&'a ::Dict>, additional: usize, ) -> ParquetResult<()> { decoded.length += additional; @@ -65,12 +68,15 @@ impl utils::Decoder for NullDecoder { NullArrayLength { length: 0 } } - fn deserialize_dict(&self, _: DictPage) -> Self::Dict {} + fn deserialize_dict(&self, _: DictPage) -> ParquetResult { + Ok(()) + } fn decode_plain_encoded<'a>( &mut self, _decoded: &mut Self::DecodedState, _page_values: &mut as utils::StateTranslation<'a, Self>>::PlainDecoder, + _is_optional: bool, _page_validity: Option<&mut utils::PageValidity<'a>>, _limit: usize, ) -> ParquetResult<()> { @@ -81,6 +87,7 @@ impl utils::Decoder for NullDecoder { &mut self, _decoded: &mut Self::DecodedState, _page_values: &mut hybrid_rle::HybridRleDecoder<'a>, + _is_optional: bool, _page_validity: Option<&mut utils::PageValidity<'a>>, _dict: &Self::Dict, _limit: usize, @@ -90,11 +97,11 @@ impl utils::Decoder for NullDecoder { fn finalize( &self, - data_type: ArrowDataType, + dtype: ArrowDataType, _dict: Option, decoded: Self::DecodedState, ) -> ParquetResult { - Ok(NullArray::new(data_type, decoded.length)) + Ok(NullArray::new(dtype, decoded.length)) } } @@ -121,9 +128,11 @@ use super::BasicDecompressor; /// Converts [`PagesIter`] to an [`ArrayIter`] pub fn iter_to_arrays( mut iter: BasicDecompressor, - data_type: ArrowDataType, + dtype: ArrowDataType, mut filter: Option, ) -> ParquetResult> { + _ = iter.read_dict_page()?; + let num_rows = Filter::opt_num_rows(&filter, iter.total_num_values()); let mut len = 0usize; @@ -134,19 +143,21 @@ pub fn iter_to_arrays( }; let page = page?; - let rows = page.num_values(); - let page_filter; - (page_filter, filter) = Filter::opt_split_at(&filter, rows); + let state_filter; + (state_filter, filter) = Filter::opt_split_at(&filter, page.num_values()); - let num_rows = match page_filter { - None => rows, + // Skip the whole page if we don't need any rows from it + 
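Both the whole-page skip just below and the do_include_at handoff above reduce to two small queries on the row filter: how many rows it keeps, and whether it keeps a given row index. A self-contained sketch with simplified stand-ins (a Vec<bool> in place of the real Bitmap mask):

use std::ops::Range;

// Simplified stand-in for the reader's Filter type.
enum Filter {
    Range(Range<usize>),
    Mask(Vec<bool>),
}

impl Filter {
    // Number of rows that survive the filter; a page whose filter slice
    // keeps zero rows can be skipped without decompressing it.
    fn num_rows(&self) -> usize {
        match self {
            Filter::Range(range) => range.len(),
            Filter::Mask(mask) => mask.iter().filter(|&&b| b).count(),
        }
    }

    // Is row `at` kept? This drives the Skip/Collect handoff between pages.
    fn do_include_at(&self, at: usize) -> bool {
        match self {
            Filter::Range(range) => range.contains(&at),
            Filter::Mask(mask) => mask[at],
        }
    }
}

fn main() {
    let f = Filter::Mask(vec![false, false, true]);
    assert_eq!(f.num_rows(), 1);
    assert!(!f.do_include_at(0));
    assert_eq!(Filter::Range(2..5).num_rows(), 3);
}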
if state_filter.as_ref().is_some_and(|f| f.num_rows() == 0) { + continue; + } + + let num_rows = match state_filter { + None => page.num_values(), Some(filter) => filter.num_rows(), }; len = (len + num_rows).min(num_rows); - - iter.reuse_page_buffer(page); } - Ok(Box::new(NullArray::new(data_type, len))) + Ok(Box::new(NullArray::new(dtype, len))) } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/basic.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/float.rs similarity index 55% rename from crates/polars-parquet/src/arrow/read/deserialize/primitive/basic.rs rename to crates/polars-parquet/src/arrow/read/deserialize/primitive/float.rs index ce658b764412..0a43141abd06 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/primitive/basic.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/float.rs @@ -2,9 +2,12 @@ use arrow::array::{DictionaryArray, DictionaryKey, PrimitiveArray}; use arrow::bitmap::MutableBitmap; use arrow::datatypes::ArrowDataType; use arrow::types::NativeType; -use polars_error::PolarsResult; use super::super::utils; +use super::{ + deserialize_plain, AsDecoderFunction, ClosureDecoderFunction, DecoderFunction, + PlainDecoderFnCollector, PrimitiveDecoder, UnitDecoderFunction, +}; use crate::parquet::encoding::hybrid_rle::DictionaryTranslator; use crate::parquet::encoding::{byte_stream_split, hybrid_rle, Encoding}; use crate::parquet::error::ParquetResult; @@ -12,153 +15,19 @@ use crate::parquet::page::{split_buffer, DataPage, DictPage}; use crate::parquet::types::{decode, NativeType as ParquetNativeType}; use crate::read::deserialize::utils::array_chunks::ArrayChunks; use crate::read::deserialize::utils::{ - freeze_validity, BatchableCollector, Decoder, PageValidity, TranslatedHybridRle, + dict_indices_decoder, freeze_validity, BatchableCollector, Decoder, PageValidity, + TranslatedHybridRle, }; -#[derive(Debug)] -pub(crate) struct ValuesDictionary<'a, T: NativeType> { - pub values: hybrid_rle::HybridRleDecoder<'a>, - pub dict: &'a Vec, -} - -impl<'a, T: NativeType> ValuesDictionary<'a, T> { - pub fn try_new(page: &'a DataPage, dict: &'a Vec) -> PolarsResult { - let values = utils::dict_indices_decoder(page)?; - - Ok(Self { dict, values }) - } - - #[inline] - pub fn len(&self) -> usize { - self.values.len() - } -} - -/// A function that defines how to decode from the -/// [`parquet::types::NativeType`][ParquetNativeType] to the [`arrow::types::NativeType`]. -/// -/// This should almost always be inlined. -pub(crate) trait DecoderFunction: Copy -where - T: NativeType, - P: ParquetNativeType, -{ - fn decode(self, x: P) -> T; -} - -#[derive(Default, Clone, Copy)] -pub(crate) struct UnitDecoderFunction(std::marker::PhantomData); -impl DecoderFunction for UnitDecoderFunction { - #[inline(always)] - fn decode(self, x: T) -> T { - x - } -} - -#[derive(Default, Clone, Copy)] -pub(crate) struct AsDecoderFunction(std::marker::PhantomData<(P, T)>); -macro_rules! 
as_decoder_impl { - ($($p:ty => $t:ty,)+) => { - $( - impl DecoderFunction<$p, $t> for AsDecoderFunction<$p, $t> { - #[inline(always)] - fn decode(self, x : $p) -> $t { - x as $t - } - } - )+ - }; -} - -as_decoder_impl![ - i32 => i8, - i32 => i16, - i32 => u8, - i32 => u16, - i32 => u32, - i64 => i32, - i64 => u32, - i64 => u64, -]; - -#[derive(Default, Clone, Copy)] -pub(crate) struct IntoDecoderFunction(std::marker::PhantomData<(P, T)>); -impl DecoderFunction for IntoDecoderFunction -where - P: ParquetNativeType + Into, - T: NativeType, -{ - #[inline(always)] - fn decode(self, x: P) -> T { - x.into() - } -} - -#[derive(Clone, Copy)] -pub(crate) struct ClosureDecoderFunction(F, std::marker::PhantomData<(P, T)>); -impl DecoderFunction for ClosureDecoderFunction -where - P: ParquetNativeType, - T: NativeType, - F: Copy + Fn(P) -> T, -{ - #[inline(always)] - fn decode(self, x: P) -> T { - (self.0)(x) - } -} - -pub(crate) struct PlainDecoderFnCollector<'a, 'b, P, T, D> -where - T: NativeType, - P: ParquetNativeType, - D: DecoderFunction, -{ - pub(crate) chunks: &'b mut ArrayChunks<'a, P>, - pub(crate) decoder: D, - pub(crate) _pd: std::marker::PhantomData, -} - -impl<'a, 'b, P, T, D: DecoderFunction> BatchableCollector<(), Vec> - for PlainDecoderFnCollector<'a, 'b, P, T, D> -where - T: NativeType, - P: ParquetNativeType, - D: DecoderFunction, -{ - fn reserve(target: &mut Vec, n: usize) { - target.reserve(n); - } - - fn push_n(&mut self, target: &mut Vec, n: usize) -> ParquetResult<()> { - let n = usize::min(self.chunks.len(), n); - let (items, remainder) = self.chunks.bytes.split_at(n); - let decoder = self.decoder; - target.extend( - items - .iter() - .map(|chunk| decoder.decode(P::from_le_bytes(*chunk))), - ); - self.chunks.bytes = remainder; - Ok(()) - } - - fn push_n_nulls(&mut self, target: &mut Vec, n: usize) -> ParquetResult<()> { - target.resize(target.len() + n, T::default()); - Ok(()) - } -} - #[allow(clippy::large_enum_variant)] #[derive(Debug)] -pub(crate) enum StateTranslation<'a, P: ParquetNativeType, T: NativeType> { +pub(crate) enum StateTranslation<'a, P: ParquetNativeType> { Plain(ArrayChunks<'a, P>), - Dictionary(ValuesDictionary<'a, T>), + Dictionary(hybrid_rle::HybridRleDecoder<'a>), ByteStreamSplit(byte_stream_split::Decoder<'a>), } -impl<'a, P, T, D> utils::StateTranslation<'a, PrimitiveDecoder> - for StateTranslation<'a, P, T> +impl<'a, P, T, D> utils::StateTranslation<'a, FloatDecoder> for StateTranslation<'a, P> where T: NativeType, P: ParquetNativeType, @@ -167,14 +36,15 @@ where type PlainDecoder = ArrayChunks<'a, P>; fn new( - _decoder: &PrimitiveDecoder, + _decoder: &FloatDecoder, page: &'a DataPage, - dict: Option<&'a as utils::Decoder>::Dict>, + dict: Option<&'a as utils::Decoder>::Dict>, _page_validity: Option<&PageValidity<'a>>, ) -> ParquetResult { match (page.encoding(), dict) { - (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict)) => { - Ok(Self::Dictionary(ValuesDictionary::try_new(page, dict)?)) + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(_)) => { + let values = dict_indices_decoder(page)?; + Ok(Self::Dictionary(values)) }, (Encoding::Plain, _) => { let values = split_buffer(page)?.values; @@ -206,8 +76,8 @@ where } match self { - Self::Plain(t) => _ = t.nth(n - 1), - Self::Dictionary(t) => t.values.skip_in_place(n)?, + Self::Plain(t) => t.skip_in_place(n), + Self::Dictionary(t) => t.skip_in_place(n)?, Self::ByteStreamSplit(t) => _ = t.iter_converted(|_| ()).nth(n - 1), } @@ -216,23 +86,27 @@ where fn extend_from_state( 
&mut self, - decoder: &mut PrimitiveDecoder, - decoded: &mut as utils::Decoder>::DecodedState, + decoder: &mut FloatDecoder, + decoded: &mut as utils::Decoder>::DecodedState, + is_optional: bool, page_validity: &mut Option>, + dict: Option<&'a as utils::Decoder>::Dict>, additional: usize, ) -> ParquetResult<()> { match self { Self::Plain(page_values) => decoder.decode_plain_encoded( decoded, page_values, + is_optional, page_validity.as_mut(), additional, )?, - Self::Dictionary(page) => decoder.decode_dictionary_encoded( + Self::Dictionary(ref mut page) => decoder.decode_dictionary_encoded( decoded, - &mut page.values, + page, + is_optional, page_validity.as_mut(), - page.dict, + dict.unwrap(), additional, )?, Self::ByteStreamSplit(page_values) => { @@ -242,16 +116,20 @@ where None => { values.extend( page_values - .iter_converted(|v| decoder.decoder.decode(decode(v))) + .iter_converted(|v| decoder.0.decoder.decode(decode(v))) .take(additional), ); + + if is_optional { + validity.extend_constant(additional, true); + } }, Some(page_validity) => utils::extend_from_decoder( validity, page_validity, Some(additional), values, - &mut page_values.iter_converted(|v| decoder.decoder.decode(decode(v))), + &mut page_values.iter_converted(|v| decoder.0.decoder.decode(decode(v))), )?, } }, @@ -262,17 +140,13 @@ where } #[derive(Debug)] -pub(crate) struct PrimitiveDecoder +pub(crate) struct FloatDecoder(PrimitiveDecoder) where P: ParquetNativeType, T: NativeType, - D: DecoderFunction, -{ - pub(crate) decoder: D, - _pd: std::marker::PhantomData<(P, T)>, -} + D: DecoderFunction; -impl PrimitiveDecoder +impl FloatDecoder where P: ParquetNativeType, T: NativeType, @@ -280,14 +154,11 @@ where { #[inline] fn new(decoder: D) -> Self { - Self { - decoder, - _pd: std::marker::PhantomData, - } + Self(PrimitiveDecoder::new(decoder)) } } -impl PrimitiveDecoder> +impl FloatDecoder> where T: NativeType + ParquetNativeType, UnitDecoderFunction: Default + DecoderFunction, @@ -297,7 +168,7 @@ where } } -impl PrimitiveDecoder> +impl FloatDecoder> where P: ParquetNativeType, T: NativeType, @@ -308,18 +179,7 @@ where } } -impl PrimitiveDecoder> -where - P: ParquetNativeType, - T: NativeType, - IntoDecoderFunction: Default + DecoderFunction, -{ - pub(crate) fn cast_into() -> Self { - Self::new(IntoDecoderFunction::::default()) - } -} - -impl PrimitiveDecoder> +impl FloatDecoder> where P: ParquetNativeType, T: NativeType, @@ -336,13 +196,13 @@ impl utils::ExactSize for (Vec, MutableBitmap) { } } -impl utils::Decoder for PrimitiveDecoder +impl utils::Decoder for FloatDecoder where T: NativeType, P: ParquetNativeType, D: DecoderFunction, { - type Translation<'a> = StateTranslation<'a, P, T>; + type Translation<'a> = StateTranslation<'a, P>; type Dict = Vec; type DecodedState = (Vec, MutableBitmap); type Output = PrimitiveArray; @@ -354,14 +214,15 @@ where ) } - fn deserialize_dict(&self, page: DictPage) -> Self::Dict { - deserialize_plain::(&page.buffer, self.decoder) + fn deserialize_dict(&self, page: DictPage) -> ParquetResult { + Ok(deserialize_plain::(&page.buffer, self.0.decoder)) } fn decode_plain_encoded<'a>( &mut self, (values, validity): &mut Self::DecodedState, page_values: &mut as utils::StateTranslation<'a, Self>>::PlainDecoder, + is_optional: bool, page_validity: Option<&mut PageValidity<'a>>, limit: usize, ) -> ParquetResult<()> { @@ -369,15 +230,19 @@ where None => { PlainDecoderFnCollector { chunks: page_values, - decoder: self.decoder, + decoder: self.0.decoder, _pd: std::marker::PhantomData, } .push_n(values, 
limit)?; + + if is_optional { + validity.extend_constant(limit, true); + } }, Some(page_validity) => { let collector = PlainDecoderFnCollector { chunks: page_values, - decoder: self.decoder, + decoder: self.0.decoder, _pd: std::marker::PhantomData, }; @@ -398,6 +263,7 @@ where &mut self, (values, validity): &mut Self::DecodedState, page_values: &mut hybrid_rle::HybridRleDecoder<'a>, + is_optional: bool, page_validity: Option<&mut PageValidity<'a>>, dict: &Self::Dict, limit: usize, @@ -407,6 +273,10 @@ where match page_validity { None => { page_values.translate_and_collect_n_into(values, limit, &translator)?; + + if is_optional { + validity.extend_constant(limit, true); + } }, Some(page_validity) => { let translated_hybridrle = TranslatedHybridRle::new(page_values, &translator); @@ -426,16 +296,16 @@ where fn finalize( &self, - data_type: ArrowDataType, + dtype: ArrowDataType, _dict: Option, (values, validity): Self::DecodedState, ) -> ParquetResult { let validity = freeze_validity(validity); - Ok(PrimitiveArray::try_new(data_type, values.into(), validity).unwrap()) + Ok(PrimitiveArray::try_new(dtype, values.into(), validity).unwrap()) } } -impl utils::DictDecodable for PrimitiveDecoder +impl utils::DictDecodable for FloatDecoder where T: NativeType, P: ParquetNativeType, @@ -443,22 +313,22 @@ where { fn finalize_dict_array( &self, - data_type: ArrowDataType, + dtype: ArrowDataType, dict: Self::Dict, keys: PrimitiveArray, ) -> ParquetResult> { - let value_type = match &data_type { + let value_type = match &dtype { ArrowDataType::Dictionary(_, value, _) => value.as_ref().clone(), _ => T::PRIMITIVE.into(), }; let dict = Box::new(PrimitiveArray::new(value_type, dict.into(), None)); - Ok(DictionaryArray::try_new(data_type, keys, dict).unwrap()) + Ok(DictionaryArray::try_new(dtype, keys, dict).unwrap()) } } -impl utils::NestedDecoder for PrimitiveDecoder +impl utils::NestedDecoder for FloatDecoder where T: NativeType, P: ParquetNativeType, @@ -481,16 +351,3 @@ where values.resize(values.len() + n, T::default()); } } - -pub(super) fn deserialize_plain(values: &[u8], decoder: D) -> Vec -where - T: NativeType, - P: ParquetNativeType, - D: DecoderFunction, -{ - values - .chunks_exact(std::mem::size_of::
<P>
()) - .map(decode) - .map(|v| decoder.decode(v)) - .collect::>() -} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs index 10aeb5b9f640..ed8e0a541a68 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/integer.rs @@ -2,12 +2,12 @@ use arrow::array::{DictionaryArray, DictionaryKey, PrimitiveArray}; use arrow::bitmap::MutableBitmap; use arrow::datatypes::ArrowDataType; use arrow::types::NativeType; -use num_traits::AsPrimitive; use super::super::utils; -use super::basic::{ - AsDecoderFunction, ClosureDecoderFunction, DecoderFunction, IntoDecoderFunction, - PlainDecoderFnCollector, PrimitiveDecoder, UnitDecoderFunction, ValuesDictionary, +use super::{ + deserialize_plain, AsDecoderFunction, ClosureDecoderFunction, DecoderFunction, DeltaCollector, + DeltaTranslator, IntoDecoderFunction, PlainDecoderFnCollector, PrimitiveDecoder, + UnitDecoderFunction, }; use crate::parquet::encoding::hybrid_rle::{self, DictionaryTranslator}; use crate::parquet::encoding::{byte_stream_split, delta_bitpacked, Encoding}; @@ -16,19 +16,20 @@ use crate::parquet::page::{split_buffer, DataPage, DictPage}; use crate::parquet::types::{decode, NativeType as ParquetNativeType}; use crate::read::deserialize::utils::array_chunks::ArrayChunks; use crate::read::deserialize::utils::{ - freeze_validity, BatchableCollector, Decoder, PageValidity, TranslatedHybridRle, + dict_indices_decoder, freeze_validity, BatchableCollector, Decoder, PageValidity, + TranslatedHybridRle, }; #[allow(clippy::large_enum_variant)] #[derive(Debug)] -pub(crate) enum StateTranslation<'a, P: ParquetNativeType, T: NativeType> { +pub(crate) enum StateTranslation<'a, P: ParquetNativeType> { Plain(ArrayChunks<'a, P>), - Dictionary(ValuesDictionary<'a, T>), + Dictionary(hybrid_rle::HybridRleDecoder<'a>), ByteStreamSplit(byte_stream_split::Decoder<'a>), DeltaBinaryPacked(delta_bitpacked::Decoder<'a>), } -impl<'a, P, T, D> utils::StateTranslation<'a, IntDecoder> for StateTranslation<'a, P, T> +impl<'a, P, T, D> utils::StateTranslation<'a, IntDecoder> for StateTranslation<'a, P> where T: NativeType, P: ParquetNativeType, @@ -44,8 +45,9 @@ where _page_validity: Option<&PageValidity<'a>>, ) -> ParquetResult { match (page.encoding(), dict) { - (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict)) => { - Ok(Self::Dictionary(ValuesDictionary::try_new(page, dict)?)) + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(_)) => { + let values = dict_indices_decoder(page)?; + Ok(Self::Dictionary(values)) }, (Encoding::Plain, _) => { let values = split_buffer(page)?.values; @@ -61,9 +63,9 @@ where }, (Encoding::DeltaBinaryPacked, _) => { let values = split_buffer(page)?.values; - Ok(Self::DeltaBinaryPacked(delta_bitpacked::Decoder::try_new( - values, - )?)) + Ok(Self::DeltaBinaryPacked( + delta_bitpacked::Decoder::try_new(values)?.0, + )) }, _ => Err(utils::not_implemented(page)), } @@ -74,7 +76,7 @@ where Self::Plain(v) => v.len(), Self::Dictionary(v) => v.len(), Self::ByteStreamSplit(v) => v.len(), - Self::DeltaBinaryPacked(v) => v.size_hint().0, + Self::DeltaBinaryPacked(v) => v.len(), } } @@ -84,10 +86,10 @@ where } match self { - Self::Plain(v) => _ = v.nth(n - 1), - Self::Dictionary(v) => v.values.skip_in_place(n)?, + Self::Plain(v) => v.skip_in_place(n), + Self::Dictionary(v) => v.skip_in_place(n)?, Self::ByteStreamSplit(v) => _ = 
v.iter_converted(|_| ()).nth(n - 1), - Self::DeltaBinaryPacked(v) => _ = v.nth(n - 1), + Self::DeltaBinaryPacked(v) => v.skip_in_place(n)?, } Ok(()) @@ -97,21 +99,25 @@ where &mut self, decoder: &mut IntDecoder, decoded: &mut as utils::Decoder>::DecodedState, + is_optional: bool, page_validity: &mut Option>, + dict: Option<&'a as utils::Decoder>::Dict>, additional: usize, ) -> ParquetResult<()> { match self { Self::Plain(page_values) => decoder.decode_plain_encoded( decoded, page_values, + is_optional, page_validity.as_mut(), additional, )?, - Self::Dictionary(page) => decoder.decode_dictionary_encoded( + Self::Dictionary(ref mut page) => decoder.decode_dictionary_encoded( decoded, - &mut page.values, + page, + is_optional, page_validity.as_mut(), - page.dict, + dict.unwrap(), additional, )?, Self::ByteStreamSplit(page_values) => { @@ -124,6 +130,10 @@ where .iter_converted(|v| decoder.0.decoder.decode(decode(v))) .take(additional), ); + + if is_optional { + validity.extend_constant(additional, true); + } }, Some(page_validity) => { utils::extend_from_decoder( @@ -140,23 +150,28 @@ where Self::DeltaBinaryPacked(page_values) => { let (values, validity) = decoded; + let mut gatherer = DeltaTranslator { + dfn: decoder.0.decoder, + _pd: std::marker::PhantomData, + }; + match page_validity { None => { - values.extend( - page_values - .by_ref() - .map(|x| decoder.0.decoder.decode(x.unwrap().as_())) - .take(additional), - ); + page_values.gather_n_into(values, additional, &mut gatherer)?; + + if is_optional { + validity.extend_constant(additional, true); + } }, Some(page_validity) => utils::extend_from_decoder( validity, page_validity, Some(additional), values, - &mut page_values - .by_ref() - .map(|x| decoder.0.decoder.decode(x.unwrap().as_())), + DeltaCollector { + decoder: page_values, + gatherer, + }, )?, } }, @@ -183,8 +198,8 @@ where D: DecoderFunction, { #[inline] - fn new(decoder: PrimitiveDecoder) -> Self { - Self(decoder) + fn new(decoder: D) -> Self { + Self(PrimitiveDecoder::new(decoder)) } } @@ -195,7 +210,7 @@ where UnitDecoderFunction: Default + DecoderFunction, { pub(crate) fn unit() -> Self { - Self::new(PrimitiveDecoder::unit()) + Self::new(UnitDecoderFunction::::default()) } } @@ -207,7 +222,7 @@ where AsDecoderFunction: Default + DecoderFunction, { pub(crate) fn cast_as() -> Self { - Self::new(PrimitiveDecoder::cast_as()) + Self::new(AsDecoderFunction::::default()) } } @@ -219,7 +234,7 @@ where IntoDecoderFunction: Default + DecoderFunction, { pub(crate) fn cast_into() -> Self { - Self::new(PrimitiveDecoder::cast_into()) + Self::new(IntoDecoderFunction::::default()) } } @@ -231,7 +246,7 @@ where F: Copy + Fn(P) -> T, { pub(crate) fn closure(f: F) -> Self { - Self::new(PrimitiveDecoder::closure(f)) + Self::new(ClosureDecoderFunction(f, std::marker::PhantomData)) } } @@ -242,23 +257,27 @@ where i64: num_traits::AsPrimitive
<P>
, D: DecoderFunction, { - type Translation<'a> = StateTranslation<'a, P, T>; + type Translation<'a> = StateTranslation<'a, P>; type Dict = Vec; type DecodedState = (Vec, MutableBitmap); type Output = PrimitiveArray; fn with_capacity(&self, capacity: usize) -> Self::DecodedState { - self.0.with_capacity(capacity) + ( + Vec::::with_capacity(capacity), + MutableBitmap::with_capacity(capacity), + ) } - fn deserialize_dict(&self, page: DictPage) -> Self::Dict { - self.0.deserialize_dict(page) + fn deserialize_dict(&self, page: DictPage) -> ParquetResult { + Ok(deserialize_plain::(&page.buffer, self.0.decoder)) } fn decode_plain_encoded<'a>( &mut self, (values, validity): &mut Self::DecodedState, page_values: &mut as utils::StateTranslation<'a, Self>>::PlainDecoder, + is_optional: bool, page_validity: Option<&mut PageValidity<'a>>, limit: usize, ) -> ParquetResult<()> { @@ -270,6 +289,10 @@ where _pd: Default::default(), } .push_n(values, limit)?; + + if is_optional { + validity.extend_constant(limit, true); + } }, Some(page_validity) => { let collector = PlainDecoderFnCollector { @@ -295,11 +318,20 @@ where &mut self, (values, validity): &mut Self::DecodedState, page_values: &mut hybrid_rle::HybridRleDecoder<'a>, + is_optional: bool, page_validity: Option<&mut PageValidity<'a>>, dict: &Self::Dict, limit: usize, ) -> ParquetResult<()> { match page_validity { + None => { + let translator = DictionaryTranslator(dict); + page_values.translate_and_collect_n_into(values, limit, &translator)?; + + if is_optional { + validity.extend_constant(limit, true); + } + }, Some(page_validity) => { let translator = DictionaryTranslator(dict); let translated_hybridrle = TranslatedHybridRle::new(page_values, &translator); @@ -312,10 +344,6 @@ where translated_hybridrle, )?; }, - None => { - let translator = DictionaryTranslator(dict); - page_values.translate_and_collect_n_into(values, limit, &translator)?; - }, } Ok(()) @@ -323,12 +351,12 @@ where fn finalize( &self, - data_type: ArrowDataType, + dtype: ArrowDataType, _dict: Option, (values, validity): Self::DecodedState, ) -> ParquetResult { let validity = freeze_validity(validity); - Ok(PrimitiveArray::try_new(data_type, values.into(), validity).unwrap()) + Ok(PrimitiveArray::try_new(dtype, values.into(), validity).unwrap()) } } @@ -341,18 +369,18 @@ where { fn finalize_dict_array( &self, - data_type: ArrowDataType, + dtype: ArrowDataType, dict: Self::Dict, keys: PrimitiveArray, ) -> ParquetResult> { - let value_type = match &data_type { + let value_type = match &dtype { ArrowDataType::Dictionary(_, value, _) => value.as_ref().clone(), _ => T::PRIMITIVE.into(), }; let dict = Box::new(PrimitiveArray::new(value_type, dict.into(), None)); - Ok(DictionaryArray::try_new(data_type, keys, dict).unwrap()) + Ok(DictionaryArray::try_new(dtype, keys, dict).unwrap()) } } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/mod.rs index c13dfa88bc3e..1a9d50a66d31 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/primitive/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/mod.rs @@ -1,5 +1,275 @@ -mod basic; +use arrow::types::NativeType; +use num_traits::AsPrimitive; + +use crate::parquet::types::{decode, NativeType as ParquetNativeType}; + +mod float; mod integer; -pub(crate) use basic::PrimitiveDecoder; +pub(crate) use float::FloatDecoder; pub(crate) use integer::IntDecoder; + +use super::utils::array_chunks::ArrayChunks; +use 
super::utils::BatchableCollector; +use super::ParquetResult; +use crate::parquet::encoding::delta_bitpacked::{self, DeltaGatherer}; + +#[derive(Debug)] +pub(crate) struct PrimitiveDecoder +where + P: ParquetNativeType, + T: NativeType, + D: DecoderFunction, +{ + pub(crate) decoder: D, + _pd: std::marker::PhantomData<(P, T)>, +} + +impl PrimitiveDecoder +where + P: ParquetNativeType, + T: NativeType, + D: DecoderFunction, +{ + #[inline] + pub(crate) fn new(decoder: D) -> Self { + Self { + decoder, + _pd: std::marker::PhantomData, + } + } +} + +/// A function that defines how to decode from the +/// [`parquet::types::NativeType`][ParquetNativeType] to the [`arrow::types::NativeType`]. +/// +/// This should almost always be inlined. +pub(crate) trait DecoderFunction: Copy +where + T: NativeType, + P: ParquetNativeType, +{ + fn decode(self, x: P) -> T; +} + +#[derive(Default, Clone, Copy)] +pub(crate) struct UnitDecoderFunction(std::marker::PhantomData); +impl DecoderFunction for UnitDecoderFunction { + #[inline(always)] + fn decode(self, x: T) -> T { + x + } +} + +#[derive(Default, Clone, Copy)] +pub(crate) struct AsDecoderFunction(std::marker::PhantomData<(P, T)>); +macro_rules! as_decoder_impl { + ($($p:ty => $t:ty,)+) => { + $( + impl DecoderFunction<$p, $t> for AsDecoderFunction<$p, $t> { + #[inline(always)] + fn decode(self, x : $p) -> $t { + x as $t + } + } + )+ + }; +} + +as_decoder_impl![ + i32 => i8, + i32 => i16, + i32 => u8, + i32 => u16, + i32 => u32, + i64 => i32, + i64 => u32, + i64 => u64, +]; + +#[derive(Default, Clone, Copy)] +pub(crate) struct IntoDecoderFunction(std::marker::PhantomData<(P, T)>); +impl DecoderFunction for IntoDecoderFunction +where + P: ParquetNativeType + Into, + T: NativeType, +{ + #[inline(always)] + fn decode(self, x: P) -> T { + x.into() + } +} + +#[derive(Clone, Copy)] +pub(crate) struct ClosureDecoderFunction(F, std::marker::PhantomData<(P, T)>); +impl DecoderFunction for ClosureDecoderFunction +where + P: ParquetNativeType, + T: NativeType, + F: Copy + Fn(P) -> T, +{ + #[inline(always)] + fn decode(self, x: P) -> T { + (self.0)(x) + } +} + +pub(crate) struct PlainDecoderFnCollector<'a, 'b, P, T, D> +where + T: NativeType, + P: ParquetNativeType, + D: DecoderFunction, +{ + pub(crate) chunks: &'b mut ArrayChunks<'a, P>, + pub(crate) decoder: D, + pub(crate) _pd: std::marker::PhantomData, +} + +impl<'a, 'b, P, T, D: DecoderFunction> BatchableCollector<(), Vec> + for PlainDecoderFnCollector<'a, 'b, P, T, D> +where + T: NativeType, + P: ParquetNativeType, + D: DecoderFunction, +{ + fn reserve(target: &mut Vec, n: usize) { + target.reserve(n); + } + + fn push_n(&mut self, target: &mut Vec, n: usize) -> ParquetResult<()> { + let n = usize::min(self.chunks.len(), n); + let (items, remainder) = self.chunks.bytes.split_at(n); + let decoder = self.decoder; + target.extend( + items + .iter() + .map(|chunk| decoder.decode(P::from_le_bytes(*chunk))), + ); + self.chunks.bytes = remainder; + Ok(()) + } + + fn push_n_nulls(&mut self, target: &mut Vec, n: usize) -> ParquetResult<()> { + target.resize(target.len() + n, T::default()); + Ok(()) + } + + fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { + self.chunks.skip_in_place(n); + Ok(()) + } +} + +fn deserialize_plain(values: &[u8], decoder: D) -> Vec +where + T: NativeType, + P: ParquetNativeType, + D: DecoderFunction, +{ + values + .chunks_exact(std::mem::size_of::
<P>
()) + .map(decode) + .map(|v| decoder.decode(v)) + .collect::>() +} + +struct DeltaTranslator +where + T: NativeType, + P: ParquetNativeType, + i64: AsPrimitive
<P>
, + D: DecoderFunction, +{ + dfn: D, + _pd: std::marker::PhantomData<(P, T)>, +} + +struct DeltaCollector<'a, 'b, P, T, D> +where + T: NativeType, + P: ParquetNativeType, + i64: AsPrimitive
<P>
, + D: DecoderFunction, +{ + decoder: &'b mut delta_bitpacked::Decoder<'a>, + gatherer: DeltaTranslator, +} + +impl DeltaGatherer for DeltaTranslator +where + T: NativeType, + P: ParquetNativeType, + i64: AsPrimitive
<P>
, + D: DecoderFunction, +{ + type Target = Vec; + + fn target_len(&self, target: &Self::Target) -> usize { + target.len() + } + + fn target_reserve(&self, target: &mut Self::Target, n: usize) { + target.reserve(n); + } + + fn gather_one(&mut self, target: &mut Self::Target, v: i64) -> ParquetResult<()> { + target.push(self.dfn.decode(v.as_())); + Ok(()) + } + + fn gather_constant( + &mut self, + target: &mut Self::Target, + v: i64, + delta: i64, + num_repeats: usize, + ) -> ParquetResult<()> { + target.extend((0..num_repeats).map(|i| self.dfn.decode((v + (i as i64) * delta).as_()))); + Ok(()) + } + + fn gather_slice(&mut self, target: &mut Self::Target, slice: &[i64]) -> ParquetResult<()> { + target.extend(slice.iter().copied().map(|v| self.dfn.decode(v.as_()))); + Ok(()) + } + + fn gather_chunk(&mut self, target: &mut Self::Target, chunk: &[i64; 64]) -> ParquetResult<()> { + target.extend(chunk.iter().copied().map(|v| self.dfn.decode(v.as_()))); + Ok(()) + } +} + +impl<'a, 'b, P, T, D> BatchableCollector<(), Vec> for DeltaCollector<'a, 'b, P, T, D> +where + T: NativeType, + P: ParquetNativeType, + i64: AsPrimitive
<P>
, + D: DecoderFunction, +{ + fn reserve(target: &mut Vec, n: usize) { + target.reserve(n); + } + + fn push_n(&mut self, target: &mut Vec, n: usize) -> ParquetResult<()> { + let start_length = target.len(); + let start_num_elems = self.decoder.len(); + + self.decoder.gather_n_into(target, n, &mut self.gatherer)?; + + let consumed_elements = usize::min(n, start_num_elems); + + debug_assert_eq!(self.decoder.len(), start_num_elems - consumed_elements); + debug_assert_eq!(target.len(), start_length + consumed_elements); + + Ok(()) + } + + fn push_n_nulls(&mut self, target: &mut Vec, n: usize) -> ParquetResult<()> { + target.resize(target.len() + n, T::default()); + Ok(()) + } + + fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { + self.decoder.skip_in_place(n) + } +} diff --git a/crates/polars-parquet/src/arrow/read/deserialize/simple.rs b/crates/polars-parquet/src/arrow/read/deserialize/simple.rs index 91a94d669c60..56912934e100 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/simple.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/simple.rs @@ -7,8 +7,7 @@ use polars_error::{polars_bail, PolarsResult}; use super::utils::filter::Filter; use super::{ - binary, boolean, dictionary, fixed_size_binary, null, primitive, BasicDecompressor, - ParquetResult, + boolean, dictionary, fixed_size_binary, null, primitive, BasicDecompressor, ParquetResult, }; use crate::parquet::error::ParquetError; use crate::parquet::schema::types::{ @@ -19,11 +18,11 @@ use crate::read::deserialize::binview::{self, BinViewDecoder}; use crate::read::deserialize::utils::PageDecoder; /// An iterator adapter that maps an iterator of Pages a boxed [`Array`] of [`ArrowDataType`] -/// `data_type` with a maximum of `num_rows` elements. +/// `dtype` with a maximum of `num_rows` elements. pub fn page_iter_to_array( pages: BasicDecompressor, type_: &PrimitiveType, - data_type: ArrowDataType, + dtype: ArrowDataType, filter: Option, ) -> PolarsResult> { use ArrowDataType::*; @@ -31,50 +30,50 @@ pub fn page_iter_to_array( let physical_type = &type_.physical_type; let logical_type = &type_.logical_type; - Ok(match (physical_type, data_type.to_logical_type()) { - (_, Null) => null::iter_to_arrays(pages, data_type, filter)?, + Ok(match (physical_type, dtype.to_logical_type()) { + (_, Null) => null::iter_to_arrays(pages, dtype, filter)?, (PhysicalType::Boolean, Boolean) => { - Box::new(PageDecoder::new(pages, data_type, boolean::BooleanDecoder)?.collect_n(filter)?) + Box::new(PageDecoder::new(pages, dtype, boolean::BooleanDecoder)?.collect_n(filter)?) }, (PhysicalType::Int32, UInt8) => Box::new(PageDecoder::new( pages, - data_type, + dtype, primitive::IntDecoder::::cast_as(), )? .collect_n(filter)?), (PhysicalType::Int32, UInt16) => Box::new(PageDecoder::new( pages, - data_type, + dtype, primitive::IntDecoder::::cast_as(), )? .collect_n(filter)?), (PhysicalType::Int32, UInt32) => Box::new(PageDecoder::new( pages, - data_type, + dtype, primitive::IntDecoder::::cast_as(), )? .collect_n(filter)?), (PhysicalType::Int64, UInt32) => Box::new(PageDecoder::new( pages, - data_type, + dtype, primitive::IntDecoder::::cast_as(), )? .collect_n(filter)?), (PhysicalType::Int32, Int8) => Box::new(PageDecoder::new( pages, - data_type, + dtype, primitive::IntDecoder::::cast_as(), )? .collect_n(filter)?), (PhysicalType::Int32, Int16) => Box::new(PageDecoder::new( pages, - data_type, + dtype, primitive::IntDecoder::::cast_as(), )? 
.collect_n(filter)?), (PhysicalType::Int32, Int32 | Date32 | Time32(_)) => Box::new(PageDecoder::new( pages, - data_type, + dtype, primitive::IntDecoder::::unit(), )? .collect_n(filter)?), @@ -84,15 +83,15 @@ pub fn page_iter_to_array( pages, physical_type, logical_type, - data_type, + dtype, filter, time_unit, ); }, (PhysicalType::FixedLenByteArray(_), FixedSizeBinary(_)) => { - let size = FixedSizeBinaryArray::get_size(&data_type); + let size = FixedSizeBinaryArray::get_size(&dtype); - Box::new(PageDecoder::new(pages, data_type, fixed_size_binary::BinaryDecoder { size })? + Box::new(PageDecoder::new(pages, dtype, fixed_size_binary::BinaryDecoder { size })? .collect_n(filter)?) }, (PhysicalType::FixedLenByteArray(12), Interval(IntervalUnit::YearMonth)) => { @@ -114,7 +113,7 @@ pub fn page_iter_to_array( let validity = array.validity().cloned(); Box::new(PrimitiveArray::::try_new( - data_type.clone(), + dtype.clone(), values.into(), validity, )?) @@ -138,20 +137,20 @@ pub fn page_iter_to_array( let validity = array.validity().cloned(); Box::new(PrimitiveArray::::try_new( - data_type.clone(), + dtype.clone(), values.into(), validity, )?) }, (PhysicalType::Int32, Decimal(_, _)) => Box::new(PageDecoder::new( pages, - data_type, + dtype, primitive::IntDecoder::::cast_into(), )? .collect_n(filter)?), (PhysicalType::Int64, Decimal(_, _)) => Box::new(PageDecoder::new( pages, - data_type, + dtype, primitive::IntDecoder::::cast_into(), )? .collect_n(filter)?), @@ -180,20 +179,20 @@ pub fn page_iter_to_array( let validity = array.validity().cloned(); Box::new(PrimitiveArray::::try_new( - data_type.clone(), + dtype.clone(), values.into(), validity, )?) }, (PhysicalType::Int32, Decimal256(_, _)) => Box::new(PageDecoder::new( pages, - data_type, + dtype, primitive::IntDecoder::closure(|x: i32| i256(I256::new(x as i128))), )? .collect_n(filter)?), (PhysicalType::Int64, Decimal256(_, _)) => Box::new(PageDecoder::new( pages, - data_type, + dtype, primitive::IntDecoder::closure(|x: i64| i256(I256::new(x as i128))), )? .collect_n(filter)?), @@ -217,7 +216,7 @@ pub fn page_iter_to_array( let validity = array.validity().cloned(); Box::new(PrimitiveArray::::try_new( - data_type.clone(), + dtype.clone(), values.into(), validity, )?) @@ -242,7 +241,7 @@ pub fn page_iter_to_array( let validity = array.validity().cloned(); Box::new(PrimitiveArray::::try_new( - data_type.clone(), + dtype.clone(), values.into(), validity, )?) @@ -254,52 +253,53 @@ pub fn page_iter_to_array( }, (PhysicalType::Int32, Date64) => Box::new(PageDecoder::new( pages, - data_type, + dtype, primitive::IntDecoder::closure(|x: i32| i64::from(x) * 86400000), )? .collect_n(filter)?), (PhysicalType::Int64, Date64) => Box::new(PageDecoder::new( pages, - data_type, + dtype, primitive::IntDecoder::::unit(), )? .collect_n(filter)?), (PhysicalType::Int64, Int64 | Time64(_) | Duration(_)) => Box::new(PageDecoder::new( pages, - data_type, + dtype, primitive::IntDecoder::::unit(), )? .collect_n(filter)?), (PhysicalType::Int64, UInt64) => Box::new(PageDecoder::new( pages, - data_type, + dtype, primitive::IntDecoder::::cast_as(), )? .collect_n(filter)?), (PhysicalType::Float, Float32) => Box::new(PageDecoder::new( pages, - data_type, - primitive::PrimitiveDecoder::::unit(), + dtype, + primitive::FloatDecoder::::unit(), )? .collect_n(filter)?), (PhysicalType::Double, Float64) => Box::new(PageDecoder::new( pages, - data_type, - primitive::PrimitiveDecoder::::unit(), + dtype, + primitive::FloatDecoder::::unit(), )? 
.collect_n(filter)?), // Don't compile this code with `i32` as we don't use this in polars (PhysicalType::ByteArray, LargeBinary | LargeUtf8) => { - PageDecoder::new(pages, data_type, binary::BinaryDecoder::::default())? + PageDecoder::new(pages, dtype, binview::BinViewDecoder::default())? .collect_n(filter)? }, + (_, Binary | Utf8) => unreachable!(), (PhysicalType::ByteArray, BinaryView | Utf8View) => { - PageDecoder::new(pages, data_type, binview::BinViewDecoder::default())? + PageDecoder::new(pages, dtype, binview::BinViewDecoder::default())? .collect_n(filter)? }, (_, Dictionary(key_type, _, _)) => { return match_integer_type!(key_type, |$K| { - dict_read::<$K>(pages, physical_type, logical_type, data_type, filter).map(|v| Box::new(v) as Box<_>) + dict_read::<$K>(pages, physical_type, logical_type, dtype, filter).map(|v| Box::new(v) as Box<_>) }).map_err(Into::into) }, (from, to) => { @@ -383,7 +383,7 @@ fn timestamp( pages: BasicDecompressor, physical_type: &PhysicalType, logical_type: &Option, - data_type: ArrowDataType, + dtype: ArrowDataType, filter: Option, time_unit: TimeUnit, ) -> PolarsResult> { @@ -392,32 +392,32 @@ fn timestamp( TimeUnit::Nanosecond => Ok(Box::new( PageDecoder::new( pages, - data_type, - primitive::PrimitiveDecoder::closure(|x: [u32; 3]| int96_to_i64_ns(x)), + dtype, + primitive::FloatDecoder::closure(|x: [u32; 3]| int96_to_i64_ns(x)), )? .collect_n(filter)?, )), TimeUnit::Microsecond => Ok(Box::new( PageDecoder::new( pages, - data_type, - primitive::PrimitiveDecoder::closure(|x: [u32; 3]| int96_to_i64_us(x)), + dtype, + primitive::FloatDecoder::closure(|x: [u32; 3]| int96_to_i64_us(x)), )? .collect_n(filter)?, )), TimeUnit::Millisecond => Ok(Box::new( PageDecoder::new( pages, - data_type, - primitive::PrimitiveDecoder::closure(|x: [u32; 3]| int96_to_i64_ms(x)), + dtype, + primitive::FloatDecoder::closure(|x: [u32; 3]| int96_to_i64_ms(x)), )? .collect_n(filter)?, )), TimeUnit::Second => Ok(Box::new( PageDecoder::new( pages, - data_type, - primitive::PrimitiveDecoder::closure(|x: [u32; 3]| int96_to_i64_s(x)), + dtype, + primitive::FloatDecoder::closure(|x: [u32; 3]| int96_to_i64_s(x)), )? .collect_n(filter)?, )), @@ -433,24 +433,16 @@ fn timestamp( let (factor, is_multiplier) = unify_timestamp_unit(logical_type, time_unit); Ok(match (factor, is_multiplier) { (1, _) => Box::new( - PageDecoder::new(pages, data_type, primitive::IntDecoder::::unit())? + PageDecoder::new(pages, dtype, primitive::IntDecoder::::unit())? .collect_n(filter)?, ), (a, true) => Box::new( - PageDecoder::new( - pages, - data_type, - primitive::IntDecoder::closure(|x: i64| x * a), - )? - .collect_n(filter)?, + PageDecoder::new(pages, dtype, primitive::IntDecoder::closure(|x: i64| x * a))? + .collect_n(filter)?, ), (a, false) => Box::new( - PageDecoder::new( - pages, - data_type, - primitive::IntDecoder::closure(|x: i64| x / a), - )? - .collect_n(filter)?, + PageDecoder::new(pages, dtype, primitive::IntDecoder::closure(|x: i64| x / a))? 
+ .collect_n(filter)?, ), }) } @@ -459,7 +451,7 @@ fn timestamp_dict( pages: BasicDecompressor, physical_type: &PhysicalType, logical_type: &Option, - data_type: ArrowDataType, + dtype: ArrowDataType, filter: Option, time_unit: TimeUnit, ) -> ParquetResult> { @@ -473,7 +465,7 @@ fn timestamp_dict( (a, true) => PageDecoder::new( pages, ArrowDataType::Timestamp(TimeUnit::Nanosecond, None), - dictionary::DictionaryDecoder::::new(primitive::PrimitiveDecoder::closure( + dictionary::DictionaryDecoder::::new(primitive::FloatDecoder::closure( |x: [u32; 3]| int96_to_i64_ns(x) * a, )), )? @@ -481,7 +473,7 @@ fn timestamp_dict( (a, false) => PageDecoder::new( pages, ArrowDataType::Timestamp(TimeUnit::Nanosecond, None), - dictionary::DictionaryDecoder::::new(primitive::PrimitiveDecoder::closure( + dictionary::DictionaryDecoder::::new(primitive::FloatDecoder::closure( |x: [u32; 3]| int96_to_i64_ns(x) / a, )), )? @@ -493,18 +485,14 @@ fn timestamp_dict( match (factor, is_multiplier) { (a, true) => PageDecoder::new( pages, - data_type, - dictionary::DictionaryDecoder::new(primitive::PrimitiveDecoder::closure(|x: i64| { - x * a - })), + dtype, + dictionary::DictionaryDecoder::new(primitive::FloatDecoder::closure(|x: i64| x * a)), )? .collect_n(filter), (a, false) => PageDecoder::new( pages, - data_type, - dictionary::DictionaryDecoder::new(primitive::PrimitiveDecoder::closure(|x: i64| { - x / a - })), + dtype, + dictionary::DictionaryDecoder::new(primitive::FloatDecoder::closure(|x: i64| x / a)), )? .collect_n(filter), } @@ -514,147 +502,109 @@ fn dict_read( iter: BasicDecompressor, physical_type: &PhysicalType, logical_type: &Option, - data_type: ArrowDataType, + dtype: ArrowDataType, filter: Option, ) -> ParquetResult> { use ArrowDataType::*; - let values_data_type = if let Dictionary(_, v, _) = &data_type { + let values_dtype = if let Dictionary(_, v, _) = &dtype { v.as_ref() } else { panic!() }; - Ok( - match (physical_type, values_data_type.to_logical_type()) { - (PhysicalType::Int32, UInt8) => PageDecoder::new( - iter, - data_type, - dictionary::DictionaryDecoder::new( - primitive::PrimitiveDecoder::::cast_as(), - ), - )? - .collect_n(filter)?, - (PhysicalType::Int32, UInt16) => PageDecoder::new( - iter, - data_type, - dictionary::DictionaryDecoder::new( - primitive::PrimitiveDecoder::::cast_as(), - ), - )? - .collect_n(filter)?, - (PhysicalType::Int32, UInt32) => PageDecoder::new( - iter, - data_type, - dictionary::DictionaryDecoder::new( - primitive::PrimitiveDecoder::::cast_as(), - ), - )? - .collect_n(filter)?, - (PhysicalType::Int64, UInt64) => PageDecoder::new( - iter, - data_type, - dictionary::DictionaryDecoder::new( - primitive::PrimitiveDecoder::::cast_as(), - ), - )? - .collect_n(filter)?, - (PhysicalType::Int32, Int8) => PageDecoder::new( - iter, - data_type, - dictionary::DictionaryDecoder::new( - primitive::PrimitiveDecoder::::cast_as(), - ), - )? - .collect_n(filter)?, - (PhysicalType::Int32, Int16) => PageDecoder::new( - iter, - data_type, - dictionary::DictionaryDecoder::new( - primitive::PrimitiveDecoder::::cast_as(), - ), - )? - .collect_n(filter)?, - ( - PhysicalType::Int32, - Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth), - ) => { - PageDecoder::new( - iter, - data_type, - dictionary::DictionaryDecoder::new( - primitive::PrimitiveDecoder::::unit(), - ), - )? - .collect_n(filter)? 
- }, - - (PhysicalType::Int64, Timestamp(time_unit, _)) => { - let time_unit = *time_unit; - return timestamp_dict::( - iter, - physical_type, - logical_type, - data_type, - filter, - time_unit, - ); - }, - - (PhysicalType::Int64, Int64 | Date64 | Time64(_) | Duration(_)) => { - PageDecoder::new( - iter, - data_type, - dictionary::DictionaryDecoder::new( - primitive::PrimitiveDecoder::::unit(), - ), - )? - .collect_n(filter)? - }, - (PhysicalType::Float, Float32) => { - PageDecoder::new( - iter, - data_type, - dictionary::DictionaryDecoder::new( - primitive::PrimitiveDecoder::::unit(), - ), - )? - .collect_n(filter)? - }, - (PhysicalType::Double, Float64) => { - PageDecoder::new( - iter, - data_type, - dictionary::DictionaryDecoder::new( - primitive::PrimitiveDecoder::::unit(), - ), - )? - .collect_n(filter)? - }, - (PhysicalType::ByteArray, LargeUtf8 | LargeBinary) => PageDecoder::new( - iter, - data_type, - dictionary::DictionaryDecoder::new(binary::BinaryDecoder::::default()), - )? - .collect_n(filter)?, - (PhysicalType::ByteArray, Utf8View | BinaryView) => PageDecoder::new( + Ok(match (physical_type, values_dtype.to_logical_type()) { + (PhysicalType::Int32, UInt8) => PageDecoder::new( + iter, + dtype, + dictionary::DictionaryDecoder::new(primitive::FloatDecoder::::cast_as()), + )? + .collect_n(filter)?, + (PhysicalType::Int32, UInt16) => PageDecoder::new( + iter, + dtype, + dictionary::DictionaryDecoder::new(primitive::FloatDecoder::::cast_as()), + )? + .collect_n(filter)?, + (PhysicalType::Int32, UInt32) => PageDecoder::new( + iter, + dtype, + dictionary::DictionaryDecoder::new(primitive::FloatDecoder::::cast_as()), + )? + .collect_n(filter)?, + (PhysicalType::Int64, UInt64) => PageDecoder::new( + iter, + dtype, + dictionary::DictionaryDecoder::new(primitive::FloatDecoder::::cast_as()), + )? + .collect_n(filter)?, + (PhysicalType::Int32, Int8) => PageDecoder::new( + iter, + dtype, + dictionary::DictionaryDecoder::new(primitive::FloatDecoder::::cast_as()), + )? + .collect_n(filter)?, + (PhysicalType::Int32, Int16) => PageDecoder::new( + iter, + dtype, + dictionary::DictionaryDecoder::new(primitive::FloatDecoder::::cast_as()), + )? + .collect_n(filter)?, + (PhysicalType::Int32, Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth)) => { + PageDecoder::new( iter, - data_type, - dictionary::DictionaryDecoder::new(BinViewDecoder::default()), + dtype, + dictionary::DictionaryDecoder::new(primitive::FloatDecoder::::unit()), )? - .collect_n(filter)?, - (PhysicalType::FixedLenByteArray(size), FixedSizeBinary(_)) => PageDecoder::new( + .collect_n(filter)? + }, + + (PhysicalType::Int64, Timestamp(time_unit, _)) => { + let time_unit = *time_unit; + return timestamp_dict::( iter, - data_type, - dictionary::DictionaryDecoder::new(fixed_size_binary::BinaryDecoder { - size: *size, - }), - )? - .collect_n(filter)?, - other => { - return Err(ParquetError::FeatureNotSupported(format!( - "Reading dictionaries of type {other:?}" - ))); - }, + physical_type, + logical_type, + dtype, + filter, + time_unit, + ); + }, + + (PhysicalType::Int64, Int64 | Date64 | Time64(_) | Duration(_)) => PageDecoder::new( + iter, + dtype, + dictionary::DictionaryDecoder::new(primitive::FloatDecoder::::unit()), + )? + .collect_n(filter)?, + (PhysicalType::Float, Float32) => PageDecoder::new( + iter, + dtype, + dictionary::DictionaryDecoder::new(primitive::FloatDecoder::::unit()), + )? 
+ .collect_n(filter)?, + (PhysicalType::Double, Float64) => PageDecoder::new( + iter, + dtype, + dictionary::DictionaryDecoder::new(primitive::FloatDecoder::::unit()), + )? + .collect_n(filter)?, + (_, LargeUtf8 | LargeBinary | Utf8 | Binary) => unreachable!(), + (PhysicalType::ByteArray, Utf8View | BinaryView) => PageDecoder::new( + iter, + dtype, + dictionary::DictionaryDecoder::new(BinViewDecoder::default()), + )? + .collect_n(filter)?, + (PhysicalType::FixedLenByteArray(size), FixedSizeBinary(_)) => PageDecoder::new( + iter, + dtype, + dictionary::DictionaryDecoder::new(fixed_size_binary::BinaryDecoder { size: *size }), + )? + .collect_n(filter)?, + other => { + return Err(ParquetError::FeatureNotSupported(format!( + "Reading dictionaries of type {other:?}" + ))); }, - ) + }) } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/utils/array_chunks.rs b/crates/polars-parquet/src/arrow/read/deserialize/utils/array_chunks.rs index f95be359631d..330ad77a7c44 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/utils/array_chunks.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/utils/array_chunks.rs @@ -24,6 +24,11 @@ impl<'a, P: ParquetNativeType> ArrayChunks<'a, P> { Some(Self { bytes }) } + + pub(crate) fn skip_in_place(&mut self, n: usize) { + let n = usize::min(self.bytes.len(), n); + self.bytes = &self.bytes[n..]; + } } impl<'a, P: ParquetNativeType> Iterator for ArrayChunks<'a, P> { @@ -36,13 +41,6 @@ impl<'a, P: ParquetNativeType> Iterator for ArrayChunks<'a, P> { Some(item) } - #[inline(always)] - fn nth(&mut self, n: usize) -> Option { - let item = self.bytes.get(n)?; - self.bytes = &self.bytes[n + 1..]; - Some(item) - } - #[inline(always)] fn size_hint(&self) -> (usize, Option) { (self.bytes.len(), Some(self.bytes.len())) diff --git a/crates/polars-parquet/src/arrow/read/deserialize/utils/filter.rs b/crates/polars-parquet/src/arrow/read/deserialize/utils/filter.rs index b7a9c6645701..03e641634467 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/utils/filter.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/utils/filter.rs @@ -22,6 +22,13 @@ impl Filter { Filter::Mask(mask) } + pub fn do_include_at(&self, at: usize) -> bool { + match self { + Filter::Range(range) => range.contains(&at), + Filter::Mask(bitmap) => bitmap.get_bit(at), + } + } + pub(crate) fn num_rows(&self) -> usize { match self { Filter::Range(range) => range.len(), diff --git a/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs index ff4cf72023d2..dba00fc97930 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/utils/mod.rs @@ -1,16 +1,12 @@ pub(crate) mod array_chunks; pub(crate) mod filter; -use arrow::array::{ - BinaryArray, DictionaryArray, DictionaryKey, MutableBinaryViewArray, PrimitiveArray, View, -}; +use arrow::array::{DictionaryArray, DictionaryKey, MutableBinaryViewArray, PrimitiveArray, View}; use arrow::bitmap::{Bitmap, MutableBitmap}; use arrow::datatypes::ArrowDataType; use arrow::pushable::Pushable; -use arrow::types::Offset; use self::filter::Filter; -use super::binary::utils::Binary; use super::BasicDecompressor; use crate::parquet::encoding::hybrid_rle::gatherer::{ HybridRleGatherer, ZeroCount, ZeroCountGatherer, @@ -22,6 +18,8 @@ use crate::parquet::schema::Repetition; #[derive(Debug)] pub(crate) struct State<'a, D: Decoder> { + pub(crate) dict: Option<&'a D::Dict>, + pub(crate) is_optional: bool, 
pub(crate) page_validity: Option>, pub(crate) translation: D::Translation<'a>, } @@ -44,7 +42,9 @@ pub(crate) trait StateTranslation<'a, D: Decoder>: Sized { &mut self, decoder: &mut D, decoded: &mut D::DecodedState, + is_optional: bool, page_validity: &mut Option>, + dict: Option<&'a D::Dict>, additional: usize, ) -> ParquetResult<()>; } @@ -54,13 +54,25 @@ impl<'a, D: Decoder> State<'a, D> { let is_optional = page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; - let page_validity = is_optional + let mut page_validity = is_optional .then(|| page_validity_decoder(page)) .transpose()?; + // Make the page_validity None if there are no nulls in the page + let null_count = page + .null_count() + .map(Ok) + .or_else(|| page_validity.as_ref().map(hybrid_rle_count_zeros)) + .transpose()?; + if null_count == Some(0) { + page_validity = None; + } + let translation = D::Translation::new(decoder, page, dict, page_validity.as_ref())?; Ok(Self { + dict, + is_optional, page_validity, translation, }) @@ -74,7 +86,11 @@ impl<'a, D: Decoder> State<'a, D> { let translation = D::Translation::new(decoder, page, dict, None)?; Ok(Self { + dict, translation, + + // Nested values may be optional, but all that is handled elsewhere. + is_optional: false, page_validity: None, }) } @@ -112,10 +128,17 @@ impl<'a, D: Decoder> State<'a, D> { match filter { None => { let num_rows = self.len(); + + if num_rows == 0 { + return Ok(()); + } + self.translation.extend_from_state( decoder, decoded, + self.is_optional, &mut self.page_validity, + self.dict, num_rows, ) }, @@ -126,12 +149,18 @@ impl<'a, D: Decoder> State<'a, D> { self.skip_in_place(start)?; debug_assert!(end - start <= self.len()); - self.translation.extend_from_state( - decoder, - decoded, - &mut self.page_validity, - end - start, - )?; + + if end - start > 0 { + self.translation.extend_from_state( + decoder, + decoded, + self.is_optional, + &mut self.page_validity, + self.dict, + end - start, + )?; + } + Ok(()) }, Filter::Mask(bitmap) => { @@ -142,12 +171,17 @@ impl<'a, D: Decoder> State<'a, D> { let prev_state_len = self.len(); let num_ones = iter.take_leading_ones(); - self.translation.extend_from_state( - decoder, - decoded, - &mut self.page_validity, - num_ones, - )?; + + if num_ones > 0 { + self.translation.extend_from_state( + decoder, + decoded, + self.is_optional, + &mut self.page_validity, + self.dict, + num_ones, + )?; + } if iter.num_remaining() == 0 || self.len() == 0 { break; @@ -171,11 +205,9 @@ impl<'a, D: Decoder> State<'a, D> { pub fn not_implemented(page: &DataPage) -> ParquetError { let is_optional = page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; - let is_filtered = page.selected_rows().is_some(); let required = if is_optional { "optional" } else { "required" }; - let is_filtered = if is_filtered { ", index-filtered" } else { "" }; ParquetError::not_supported(format!( - "Decoding {:?} \"{:?}\"-encoded {required}{is_filtered} parquet pages not yet supported", + "Decoding {:?} \"{:?}\"-encoded {required} parquet pages not yet supported", page.descriptor.primitive_type.physical_type, page.encoding(), )) @@ -185,14 +217,15 @@ pub trait BatchableCollector { fn reserve(target: &mut T, n: usize); fn push_n(&mut self, target: &mut T, n: usize) -> ParquetResult<()>; fn push_n_nulls(&mut self, target: &mut T, n: usize) -> ParquetResult<()>; + fn skip_in_place(&mut self, n: usize) -> ParquetResult<()>; } /// This batches sequential collect operations to try and prevent unnecessary buffering and /// 
`Iterator::next` polling. #[must_use] pub struct BatchedCollector<'a, I, T, C: BatchableCollector> { - num_waiting_valids: usize, - num_waiting_invalids: usize, + pub(crate) num_waiting_valids: usize, + pub(crate) num_waiting_invalids: usize, target: &'a mut T, collector: C, @@ -243,6 +276,24 @@ impl<'a, I, T, C: BatchableCollector> BatchedCollector<'a, I, T, C> { self.num_waiting_invalids += n; } + #[inline] + pub fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { + if self.num_waiting_valids > 0 { + self.collector + .push_n(self.target, self.num_waiting_valids)?; + self.num_waiting_valids = 0; + } + if self.num_waiting_invalids > 0 { + self.collector + .push_n_nulls(self.target, self.num_waiting_invalids)?; + self.num_waiting_invalids = 0; + } + + self.collector.skip_in_place(n)?; + + Ok(()) + } + #[inline] pub fn finalize(mut self) -> ParquetResult<()> { self.collector @@ -403,6 +454,11 @@ where target.resize(target.len() + n, O::default()); Ok(()) } + + #[inline] + fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { + self.decoder.skip_in_place(n) + } } pub struct GatheredHybridRle<'a, 'b, 'c, O, G> @@ -453,32 +509,10 @@ where .gather_repeated(target, self.null_value.clone(), n)?; Ok(()) } -} -impl<'a, 'b, 'c, O, Out, G> BatchableCollector> - for GatheredHybridRle<'a, 'b, 'c, Out, G> -where - O: Offset, - Out: Clone, - G: HybridRleGatherer>, -{ #[inline] - fn reserve(target: &mut Binary, n: usize) { - target.offsets.reserve(n); - target.values.reserve(n); - } - - #[inline] - fn push_n(&mut self, target: &mut Binary, n: usize) -> ParquetResult<()> { - self.decoder.gather_n_into(target, n, self.gatherer)?; - Ok(()) - } - - #[inline] - fn push_n_nulls(&mut self, target: &mut Binary, n: usize) -> ParquetResult<()> { - self.gatherer - .gather_repeated(target, self.null_value.clone(), n)?; - Ok(()) + fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { + self.decoder.skip_in_place(n) } } @@ -494,8 +528,11 @@ where #[inline] fn push_n(&mut self, target: &mut MutableBinaryViewArray<[u8]>, n: usize) -> ParquetResult<()> { - self.decoder - .translate_and_collect_n_into(target.views_mut(), n, self.translator)?; + self.decoder.translate_and_collect_n_into( + unsafe { target.views_mut() }, + n, + self.translator, + )?; if let Some(validity) = target.validity() { validity.extend_constant(n, true); @@ -513,6 +550,11 @@ where target.extend_null(n); Ok(()) } + + #[inline] + fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { + self.decoder.skip_in_place(n) + } } impl, I: Iterator> BatchableCollector for I { @@ -532,6 +574,14 @@ impl, I: Iterator> BatchableCollector for I { target.extend_null_constant(n); Ok(()) } + + #[inline] + fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { + if n > 0 { + _ = self.nth(n - 1); + } + Ok(()) + } } /// An item with a known size @@ -555,12 +605,21 @@ pub(super) trait Decoder: Sized { fn with_capacity(&self, capacity: usize) -> Self::DecodedState; /// Deserializes a [`DictPage`] into [`Self::Dict`]. 
- fn deserialize_dict(&self, page: DictPage) -> Self::Dict; + fn deserialize_dict(&self, page: DictPage) -> ParquetResult; + + fn apply_dictionary( + &mut self, + _decoded: &mut Self::DecodedState, + _dict: &Self::Dict, + ) -> ParquetResult<()> { + Ok(()) + } fn decode_plain_encoded<'a>( &mut self, decoded: &mut Self::DecodedState, page_values: &mut as StateTranslation<'a, Self>>::PlainDecoder, + is_optional: bool, page_validity: Option<&mut PageValidity<'a>>, limit: usize, ) -> ParquetResult<()>; @@ -568,6 +627,7 @@ pub(super) trait Decoder: Sized { &mut self, decoded: &mut Self::DecodedState, page_values: &mut HybridRleDecoder<'a>, + is_optional: bool, page_validity: Option<&mut PageValidity<'a>>, dict: &Self::Dict, limit: usize, @@ -575,7 +635,7 @@ pub(super) trait Decoder: Sized { fn finalize( &self, - data_type: ArrowDataType, + dtype: ArrowDataType, dict: Option, decoded: Self::DecodedState, ) -> ParquetResult; @@ -616,7 +676,7 @@ pub(crate) trait NestedDecoder: Decoder { pub trait DictDecodable: Decoder { fn finalize_dict_array( &self, - data_type: ArrowDataType, + dtype: ArrowDataType, dict: Self::Dict, keys: PrimitiveArray, ) -> ParquetResult>; @@ -624,7 +684,7 @@ pub trait DictDecodable: Decoder { pub struct PageDecoder { pub iter: BasicDecompressor, - pub data_type: ArrowDataType, + pub dtype: ArrowDataType, pub dict: Option, pub decoder: D, } @@ -632,15 +692,15 @@ pub struct PageDecoder { impl PageDecoder { pub fn new( mut iter: BasicDecompressor, - data_type: ArrowDataType, + dtype: ArrowDataType, decoder: D, ) -> ParquetResult { let dict_page = iter.read_dict_page()?; - let dict = dict_page.map(|d| decoder.deserialize_dict(d)); + let dict = dict_page.map(|d| decoder.deserialize_dict(d)).transpose()?; Ok(Self { iter, - data_type, + dtype, dict, decoder, }) @@ -651,17 +711,27 @@ impl PageDecoder { let mut target = self.decoder.with_capacity(num_rows_remaining); + if let Some(dict) = self.dict.as_ref() { + self.decoder.apply_dictionary(&mut target, dict)?; + } + while num_rows_remaining > 0 { let Some(page) = self.iter.next() else { - return self.decoder.finalize(self.data_type, self.dict, target); + break; }; let page = page?; - let mut state = State::new(&self.decoder, &page, self.dict.as_ref())?; - let state_len = state.len(); - let state_filter; - (state_filter, filter) = Filter::opt_split_at(&filter, state_len); + (state_filter, filter) = Filter::opt_split_at(&filter, page.num_values()); + + // Skip the whole page if we don't need any rows from it + if state_filter.as_ref().is_some_and(|f| f.num_rows() == 0) { + continue; + } + + let page = page.decompress(&mut self.iter)?; + + let mut state = State::new(&self.decoder, &page, self.dict.as_ref())?; let start_length = target.len(); state.extend_from_state(&mut self.decoder, &mut target, state_filter)?; @@ -675,7 +745,7 @@ impl PageDecoder { self.iter.reuse_page_buffer(page); } - self.decoder.finalize(self.data_type, self.dict, target) + self.decoder.finalize(self.dtype, self.dict, target) } } @@ -695,43 +765,6 @@ pub(super) fn dict_indices_decoder(page: &DataPage) -> ParquetResult, - dict: &BinaryArray, -) -> Vec { - // We create a dictionary of views here, so that the views only have be calculated - // once and are then just a lookup. We also only push the dictionary buffer when we - // see the first View that cannot be inlined. - // - // @TODO: Maybe we can do something smarter here by only pushing the items that are larger than - // 12 bytes. Maybe, we say if the num_inlined < dict.len() / 2 then push the whole buffer. 
- // Otherwise, only push the non-inlinable items. - - let mut buffer_idx = None; - dict.values_iter() - .enumerate() - .map(|(i, value)| { - if value.len() <= View::MAX_INLINE_SIZE as usize { - View::new_inline(value) - } else { - let (offset_start, offset_end) = dict.offsets().start_end(i); - debug_assert_eq!(value.len(), offset_end - offset_start); - - let buffer_idx = - buffer_idx.get_or_insert_with(|| values.push_buffer(dict.values().clone())); - - debug_assert!(offset_start <= u32::MAX as usize); - View::new_from_bytes(value, *buffer_idx, offset_start as u32) - } - }) - .collect() -} - /// Freeze a [`MutableBitmap`] into a `Option`. /// /// This will turn the several instances where `None` (representing "all valid") suffices. @@ -748,3 +781,13 @@ pub fn freeze_validity(validity: MutableBitmap) -> Option { Some(validity) } + +pub(crate) fn hybrid_rle_count_zeros( + decoder: &hybrid_rle::HybridRleDecoder<'_>, +) -> ParquetResult { + let mut count = ZeroCount::default(); + decoder + .clone() + .gather_into(&mut count, &ZeroCountGatherer)?; + Ok(count.num_zero) +} diff --git a/crates/polars-parquet/src/arrow/read/indexes/binary.rs b/crates/polars-parquet/src/arrow/read/indexes/binary.rs deleted file mode 100644 index b6e017644746..000000000000 --- a/crates/polars-parquet/src/arrow/read/indexes/binary.rs +++ /dev/null @@ -1,44 +0,0 @@ -use arrow::array::{Array, BinaryArray, PrimitiveArray, Utf8Array}; -use arrow::datatypes::{ArrowDataType, PhysicalType}; -use arrow::trusted_len::TrustedLen; -use polars_error::{to_compute_err, PolarsResult}; - -use super::ColumnPageStatistics; -use crate::parquet::indexes::PageIndex; - -pub fn deserialize( - indexes: &[PageIndex>], - data_type: &ArrowDataType, -) -> PolarsResult { - Ok(ColumnPageStatistics { - min: deserialize_binary_iter(indexes.iter().map(|index| index.min.as_ref()), data_type)?, - max: deserialize_binary_iter(indexes.iter().map(|index| index.max.as_ref()), data_type)?, - null_count: PrimitiveArray::from_trusted_len_iter( - indexes - .iter() - .map(|index| index.null_count.map(|x| x as u64)), - ), - }) -} - -fn deserialize_binary_iter<'a, I: TrustedLen>>>( - iter: I, - data_type: &ArrowDataType, -) -> PolarsResult> { - match data_type.to_physical_type() { - PhysicalType::LargeBinary => Ok(Box::new(BinaryArray::::from_iter(iter))), - PhysicalType::Utf8 => { - let iter = iter.map(|x| x.map(|x| std::str::from_utf8(x)).transpose()); - Ok(Box::new( - Utf8Array::::try_from_trusted_len_iter(iter).map_err(to_compute_err)?, - )) - }, - PhysicalType::LargeUtf8 => { - let iter = iter.map(|x| x.map(|x| std::str::from_utf8(x)).transpose()); - Ok(Box::new( - Utf8Array::::try_from_trusted_len_iter(iter).map_err(to_compute_err)?, - )) - }, - _ => Ok(Box::new(BinaryArray::::from_iter(iter))), - } -} diff --git a/crates/polars-parquet/src/arrow/read/indexes/boolean.rs b/crates/polars-parquet/src/arrow/read/indexes/boolean.rs deleted file mode 100644 index b6414e24a621..000000000000 --- a/crates/polars-parquet/src/arrow/read/indexes/boolean.rs +++ /dev/null @@ -1,20 +0,0 @@ -use arrow::array::{BooleanArray, PrimitiveArray}; - -use super::ColumnPageStatistics; -use crate::parquet::indexes::PageIndex; - -pub fn deserialize(indexes: &[PageIndex]) -> ColumnPageStatistics { - ColumnPageStatistics { - min: Box::new(BooleanArray::from_trusted_len_iter( - indexes.iter().map(|index| index.min), - )), - max: Box::new(BooleanArray::from_trusted_len_iter( - indexes.iter().map(|index| index.max), - )), - null_count: PrimitiveArray::from_trusted_len_iter( - indexes - 
.iter() - .map(|index| index.null_count.map(|x| x as u64)), - ), - } -} diff --git a/crates/polars-parquet/src/arrow/read/indexes/fixed_len_binary.rs b/crates/polars-parquet/src/arrow/read/indexes/fixed_len_binary.rs deleted file mode 100644 index 5b2785b22b06..000000000000 --- a/crates/polars-parquet/src/arrow/read/indexes/fixed_len_binary.rs +++ /dev/null @@ -1,70 +0,0 @@ -use arrow::array::{Array, FixedSizeBinaryArray, MutableFixedSizeBinaryArray, PrimitiveArray}; -use arrow::datatypes::{ArrowDataType, PhysicalType, PrimitiveType}; -use arrow::trusted_len::TrustedLen; -use arrow::types::{i256, NativeType}; - -use super::ColumnPageStatistics; -use crate::parquet::indexes::PageIndex; - -pub fn deserialize( - indexes: &[PageIndex>], - data_type: ArrowDataType, -) -> ColumnPageStatistics { - ColumnPageStatistics { - min: deserialize_binary_iter( - indexes.iter().map(|index| index.min.as_ref()), - data_type.clone(), - ), - max: deserialize_binary_iter(indexes.iter().map(|index| index.max.as_ref()), data_type), - null_count: PrimitiveArray::from_trusted_len_iter( - indexes - .iter() - .map(|index| index.null_count.map(|x| x as u64)), - ), - } -} - -fn deserialize_binary_iter<'a, I: TrustedLen>>>( - iter: I, - data_type: ArrowDataType, -) -> Box { - match data_type.to_physical_type() { - PhysicalType::Primitive(PrimitiveType::Int128) => { - Box::new(PrimitiveArray::from_trusted_len_iter(iter.map(|v| { - v.map(|x| { - // Copy the fixed-size byte value to the start of a 16 byte stack - // allocated buffer, then use an arithmetic right shift to fill in - // MSBs, which accounts for leading 1's in negative (two's complement) - // values. - let n = x.len(); - let mut bytes = [0u8; 16]; - bytes[..n].copy_from_slice(x); - i128::from_be_bytes(bytes) >> (8 * (16 - n)) - }) - }))) - }, - PhysicalType::Primitive(PrimitiveType::Int256) => { - Box::new(PrimitiveArray::from_trusted_len_iter(iter.map(|v| { - v.map(|x| { - let n = x.len(); - let mut bytes = [0u8; 32]; - bytes[..n].copy_from_slice(x); - i256::from_be_bytes(bytes) - }) - }))) - }, - _ => { - let mut a = MutableFixedSizeBinaryArray::try_new( - data_type, - Vec::with_capacity(iter.size_hint().0), - None, - ) - .unwrap(); - for item in iter { - a.push(item); - } - let a: FixedSizeBinaryArray = a.into(); - Box::new(a) - }, - } -} diff --git a/crates/polars-parquet/src/arrow/read/indexes/mod.rs b/crates/polars-parquet/src/arrow/read/indexes/mod.rs deleted file mode 100644 index 9cf465c64206..000000000000 --- a/crates/polars-parquet/src/arrow/read/indexes/mod.rs +++ /dev/null @@ -1,377 +0,0 @@ -//! API to perform page-level filtering (also known as indexes) -use crate::parquet::error::ParquetError; -use crate::parquet::indexes::{ - select_pages, BooleanIndex, ByteIndex, FixedLenByteIndex, Index as ParquetIndex, NativeIndex, - PageLocation, -}; -use crate::parquet::metadata::{ColumnChunkMetaData, RowGroupMetaData}; -use crate::parquet::read::{read_columns_indexes as _read_columns_indexes, read_pages_locations}; -use crate::parquet::schema::types::PhysicalType as ParquetPhysicalType; - -mod binary; -mod boolean; -mod fixed_len_binary; -mod primitive; - -use std::collections::VecDeque; -use std::io::{Read, Seek}; - -use arrow::array::{Array, UInt64Array}; -use arrow::datatypes::{ArrowDataType, Field, PhysicalType, PrimitiveType}; -use polars_error::{polars_bail, PolarsResult}; - -use super::get_field_pages; -pub use crate::parquet::indexes::{FilteredPage, Interval}; - -/// Page statistics of an Arrow field. 
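Editor's note: the deleted fixed-length-binary statistics path above decodes big-endian two's-complement values of at most 16 bytes by copying them into the most-significant end of a 16-byte buffer and then sign-extending with an arithmetic right shift. A standalone sketch of that trick, separate from the removed function:

/// Decode an `n`-byte big-endian two's-complement integer (1 <= n <= 16) into an i128.
fn be_bytes_to_i128(x: &[u8]) -> i128 {
    let n = x.len();
    assert!((1..=16).contains(&n));
    let mut bytes = [0u8; 16];
    // Place the value in the most-significant bytes...
    bytes[..n].copy_from_slice(x);
    // ...then an arithmetic shift right fills the missing high bytes with the sign bit.
    i128::from_be_bytes(bytes) >> (8 * (16 - n))
}

fn main() {
    assert_eq!(be_bytes_to_i128(&[0xFF, 0xFE]), -2); // 0xFFFE as a 2-byte two's-complement value
    assert_eq!(be_bytes_to_i128(&[0x01, 0x00]), 256);
}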
-#[derive(Debug, PartialEq)] -pub enum FieldPageStatistics { - /// Variant used for fields with a single parquet column (e.g. primitives, dictionaries, list) - Single(ColumnPageStatistics), - /// Variant used for fields with multiple parquet columns (e.g. Struct, Map) - Multiple(Vec), -} - -impl From for FieldPageStatistics { - fn from(column: ColumnPageStatistics) -> Self { - Self::Single(column) - } -} - -/// [`ColumnPageStatistics`] contains the minimum, maximum, and null_count -/// of each page of a parquet column, as an [`Array`]. -/// This struct has the following invariants: -/// * `min`, `max` and `null_count` have the same length (equal to the number of pages in the column) -/// * `min`, `max` and `null_count` are guaranteed to be non-null -/// * `min` and `max` have the same logical type -#[derive(Debug, PartialEq)] -pub struct ColumnPageStatistics { - /// The minimum values in the pages - pub min: Box, - /// The maximum values in the pages - pub max: Box, - /// The number of null values in the pages. - pub null_count: UInt64Array, -} - -/// Given a sequence of [`ParquetIndex`] representing the page indexes of each column in the -/// parquet file, returns the page-level statistics as a [`FieldPageStatistics`]. -/// -/// This function maps timestamps, decimal types, etc. accordingly. -/// # Implementation -/// This function is CPU-bounded `O(P)` where `P` is the total number of pages on all columns. -/// # Error -/// This function errors iff the value is not deserializable to arrow (e.g. invalid utf-8) -fn deserialize( - indexes: &mut VecDeque<&dyn ParquetIndex>, - data_type: ArrowDataType, -) -> PolarsResult { - match data_type.to_physical_type() { - PhysicalType::Boolean => { - let index = indexes - .pop_front() - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - Ok(boolean::deserialize(&index.indexes).into()) - }, - PhysicalType::Primitive(PrimitiveType::Int128) => { - let index = indexes.pop_front().unwrap(); - match index.physical_type() { - ParquetPhysicalType::Int32 => { - let index = index.as_any().downcast_ref::>().unwrap(); - Ok(primitive::deserialize_i32(&index.indexes, data_type).into()) - }, - crate::parquet::schema::types::PhysicalType::Int64 => { - let index = index.as_any().downcast_ref::>().unwrap(); - Ok( - primitive::deserialize_i64( - &index.indexes, - &index.primitive_type, - data_type, - ) - .into(), - ) - }, - crate::parquet::schema::types::PhysicalType::FixedLenByteArray(_) => { - let index = index.as_any().downcast_ref::().unwrap(); - Ok(fixed_len_binary::deserialize(&index.indexes, data_type).into()) - }, - other => polars_bail!(nyi = "Deserialize {other:?} to arrow's int64"), - } - }, - PhysicalType::Primitive(PrimitiveType::Int256) => { - let index = indexes.pop_front().unwrap(); - match index.physical_type() { - ParquetPhysicalType::Int32 => { - let index = index.as_any().downcast_ref::>().unwrap(); - Ok(primitive::deserialize_i32(&index.indexes, data_type).into()) - }, - crate::parquet::schema::types::PhysicalType::Int64 => { - let index = index.as_any().downcast_ref::>().unwrap(); - Ok( - primitive::deserialize_i64( - &index.indexes, - &index.primitive_type, - data_type, - ) - .into(), - ) - }, - crate::parquet::schema::types::PhysicalType::FixedLenByteArray(_) => { - let index = index.as_any().downcast_ref::().unwrap(); - Ok(fixed_len_binary::deserialize(&index.indexes, data_type).into()) - }, - other => polars_bail!(nyi = "Deserialize {other:?} to arrow's int64"), - } - }, - PhysicalType::Primitive(PrimitiveType::UInt8) - | 
PhysicalType::Primitive(PrimitiveType::UInt16) - | PhysicalType::Primitive(PrimitiveType::UInt32) - | PhysicalType::Primitive(PrimitiveType::Int32) => { - let index = indexes - .pop_front() - .unwrap() - .as_any() - .downcast_ref::>() - .unwrap(); - Ok(primitive::deserialize_i32(&index.indexes, data_type).into()) - }, - PhysicalType::Primitive(PrimitiveType::UInt64) - | PhysicalType::Primitive(PrimitiveType::Int64) => { - let index = indexes.pop_front().unwrap(); - match index.physical_type() { - ParquetPhysicalType::Int64 => { - let index = index.as_any().downcast_ref::>().unwrap(); - Ok( - primitive::deserialize_i64( - &index.indexes, - &index.primitive_type, - data_type, - ) - .into(), - ) - }, - crate::parquet::schema::types::PhysicalType::Int96 => { - let index = index - .as_any() - .downcast_ref::>() - .unwrap(); - Ok(primitive::deserialize_i96(&index.indexes, data_type).into()) - }, - other => polars_bail!(nyi = "Deserialize {other:?} to arrow's int64"), - } - }, - PhysicalType::Primitive(PrimitiveType::Float32) => { - let index = indexes - .pop_front() - .unwrap() - .as_any() - .downcast_ref::>() - .unwrap(); - Ok(primitive::deserialize_id(&index.indexes, data_type).into()) - }, - PhysicalType::Primitive(PrimitiveType::Float64) => { - let index = indexes - .pop_front() - .unwrap() - .as_any() - .downcast_ref::>() - .unwrap(); - Ok(primitive::deserialize_id(&index.indexes, data_type).into()) - }, - PhysicalType::Binary - | PhysicalType::LargeBinary - | PhysicalType::Utf8 - | PhysicalType::LargeUtf8 - | PhysicalType::Utf8View - | PhysicalType::BinaryView => { - let index = indexes - .pop_front() - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - binary::deserialize(&index.indexes, &data_type).map(|x| x.into()) - }, - PhysicalType::FixedSizeBinary => { - let index = indexes - .pop_front() - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - Ok(fixed_len_binary::deserialize(&index.indexes, data_type).into()) - }, - PhysicalType::Dictionary(_) => { - if let ArrowDataType::Dictionary(_, inner, _) = data_type.to_logical_type() { - deserialize(indexes, (**inner).clone()) - } else { - unreachable!() - } - }, - PhysicalType::List => { - if let ArrowDataType::List(inner) = data_type.to_logical_type() { - deserialize(indexes, inner.data_type.clone()) - } else { - unreachable!() - } - }, - PhysicalType::LargeList => { - if let ArrowDataType::LargeList(inner) = data_type.to_logical_type() { - deserialize(indexes, inner.data_type.clone()) - } else { - unreachable!() - } - }, - PhysicalType::Map => { - if let ArrowDataType::Map(inner, _) = data_type.to_logical_type() { - deserialize(indexes, inner.data_type.clone()) - } else { - unreachable!() - } - }, - PhysicalType::Struct => { - let children_fields = - if let ArrowDataType::Struct(children) = data_type.to_logical_type() { - children - } else { - unreachable!() - }; - let children = children_fields - .iter() - .map(|child| deserialize(indexes, child.data_type.clone())) - .collect::>>()?; - - Ok(FieldPageStatistics::Multiple(children)) - }, - - other => polars_bail!(nyi = "Deserialize into arrow's {other:?} page index"), - } -} - -/// Checks whether the row group have page index information (page statistics) -pub fn has_indexes(row_group: &RowGroupMetaData) -> bool { - row_group - .columns() - .iter() - .all(|chunk| chunk.column_chunk().column_index_offset.is_some()) -} - -/// Reads the column indexes from the reader assuming a valid set of derived Arrow fields -/// for all parquet the columns in the file. 
-/// -/// It returns one [`FieldPageStatistics`] per field in `fields` -/// -/// This function is expected to be used to filter out parquet pages. -/// -/// # Implementation -/// This function is IO-bounded and calls `reader.read_exact` exactly once. -/// # Error -/// Errors iff the indexes can't be read or their deserialization to arrow is incorrect (e.g. invalid utf-8) -pub fn read_columns_indexes( - reader: &mut R, - chunks: &[ColumnChunkMetaData], - fields: &[Field], -) -> PolarsResult> { - let indexes = _read_columns_indexes(reader, chunks)?; - - fields - .iter() - .map(|field| { - let indexes = get_field_pages(chunks, &indexes, &field.name); - let mut indexes = indexes.into_iter().map(|boxed| boxed.as_ref()).collect(); - - deserialize(&mut indexes, field.data_type.clone()) - }) - .collect() -} - -/// Returns the set of (row) intervals of the pages. -pub fn compute_page_row_intervals( - locations: &[PageLocation], - num_rows: usize, -) -> Result, ParquetError> { - if locations.is_empty() { - return Ok(vec![]); - }; - - let last = (|| { - let start: usize = locations.last().unwrap().first_row_index.try_into()?; - let length = num_rows - start; - Result::<_, ParquetError>::Ok(Interval::new(start, length)) - })(); - - let pages_lengths = locations - .windows(2) - .map(|x| { - let start = usize::try_from(x[0].first_row_index)?; - let length = usize::try_from(x[1].first_row_index - x[0].first_row_index)?; - Ok(Interval::new(start, length)) - }) - .chain(std::iter::once(last)); - pages_lengths.collect() -} - -/// Reads all page locations and index locations (IO-bounded) and uses `predicate` to compute -/// the set of [`FilteredPage`] that fulfill the predicate. -/// -/// The non-trivial argument of this function is `predicate`, that controls which pages are selected. -/// Its signature contains 2 arguments: -/// * 0th argument (indexes): contains one [`ColumnPageStatistics`] (page statistics) per field. -/// Use it to evaluate the predicate against -/// * 1th argument (intervals): contains one [`Vec>`] (row positions) per field. -/// For each field, the outermost vector corresponds to each parquet column: -/// a primitive field contains 1 column, a struct field with 2 primitive fields contain 2 columns. -/// The inner `Vec` contains one [`Interval`] per page: its length equals the length of [`ColumnPageStatistics`]. -/// -/// It returns a single [`Vec`] denoting the set of intervals that the predicate selects (over all columns). -/// -/// This returns one item per `field`. For each field, there is one item per column (for non-nested types it returns one column) -/// and finally [`Vec`], that corresponds to the set of selected pages. 
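Editor's note: the removed `compute_page_row_intervals` derives each page's row interval from the sorted `first_row_index` of consecutive page locations — every page spans from its own first row up to the next page's first row, and the last page runs to the end of the row group. A small sketch of the same computation over plain integers (the `Interval` type here is a stand-in, not the deleted API):

#[derive(Debug, PartialEq)]
struct Interval {
    start: usize,
    length: usize,
}

/// Turn the first-row index of each page into (start, length) intervals,
/// assuming `first_rows` is sorted and `num_rows` is the row-group length.
fn page_row_intervals(first_rows: &[usize], num_rows: usize) -> Vec<Interval> {
    let mut out = Vec::with_capacity(first_rows.len());
    for (i, &start) in first_rows.iter().enumerate() {
        let end = first_rows.get(i + 1).copied().unwrap_or(num_rows);
        out.push(Interval { start, length: end - start });
    }
    out
}

fn main() {
    let intervals = page_row_intervals(&[0, 100, 250], 300);
    assert_eq!(
        intervals,
        vec![
            Interval { start: 0, length: 100 },
            Interval { start: 100, length: 150 },
            Interval { start: 250, length: 50 },
        ]
    );
}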
-pub fn read_filtered_pages< - R: Read + Seek, - F: Fn(&[FieldPageStatistics], &[Vec>]) -> Vec, ->( - reader: &mut R, - row_group: &RowGroupMetaData, - fields: &[Field], - predicate: F, - //is_intersection: bool, -) -> PolarsResult>>> { - let num_rows = row_group.num_rows(); - - // one vec per column - let locations = read_pages_locations(reader, row_group.columns())?; - // one Vec> per field (non-nested contain a single entry on the first column) - let locations = fields - .iter() - .map(|field| get_field_pages(row_group.columns(), &locations, &field.name)) - .collect::>(); - - // one ColumnPageStatistics per field - let indexes = read_columns_indexes(reader, row_group.columns(), fields)?; - - let intervals = locations - .iter() - .map(|locations| { - locations - .iter() - .map(|locations| Ok(compute_page_row_intervals(locations, num_rows)?)) - .collect::>>() - }) - .collect::>>()?; - - let intervals = predicate(&indexes, &intervals); - - locations - .into_iter() - .map(|locations| { - locations - .into_iter() - .map(|locations| Ok(select_pages(&intervals, locations, num_rows)?)) - .collect::>>() - }) - .collect() -} diff --git a/crates/polars-parquet/src/arrow/read/indexes/primitive.rs b/crates/polars-parquet/src/arrow/read/indexes/primitive.rs deleted file mode 100644 index dfd72bc9c54e..000000000000 --- a/crates/polars-parquet/src/arrow/read/indexes/primitive.rs +++ /dev/null @@ -1,227 +0,0 @@ -use arrow::array::{Array, MutablePrimitiveArray, PrimitiveArray}; -use arrow::datatypes::{ArrowDataType, TimeUnit}; -use arrow::trusted_len::TrustedLen; -use arrow::types::{i256, NativeType}; -use ethnum::I256; - -use super::ColumnPageStatistics; -use crate::parquet::indexes::PageIndex; -use crate::parquet::schema::types::{ - PrimitiveLogicalType, PrimitiveType, TimeUnit as ParquetTimeUnit, -}; -use crate::parquet::types::int96_to_i64_ns; - -#[inline] -fn deserialize_int32>>( - iter: I, - data_type: ArrowDataType, -) -> Box { - use ArrowDataType::*; - match data_type.to_logical_type() { - UInt8 => Box::new( - PrimitiveArray::::from_trusted_len_iter(iter.map(|x| x.map(|x| x as u8))) - .to(data_type), - ) as _, - UInt16 => Box::new( - PrimitiveArray::::from_trusted_len_iter(iter.map(|x| x.map(|x| x as u16))) - .to(data_type), - ), - UInt32 => Box::new( - PrimitiveArray::::from_trusted_len_iter(iter.map(|x| x.map(|x| x as u32))) - .to(data_type), - ), - Decimal(_, _) => Box::new( - PrimitiveArray::::from_trusted_len_iter(iter.map(|x| x.map(|x| x as i128))) - .to(data_type), - ), - Decimal256(_, _) => Box::new( - PrimitiveArray::::from_trusted_len_iter( - iter.map(|x| x.map(|x| i256(I256::new(x.into())))), - ) - .to(data_type), - ) as _, - _ => Box::new(PrimitiveArray::::from_trusted_len_iter(iter).to(data_type)), - } -} - -#[inline] -fn timestamp( - array: &mut MutablePrimitiveArray, - time_unit: TimeUnit, - logical_type: Option, -) { - let unit = if let Some(PrimitiveLogicalType::Timestamp { unit, .. 
}) = logical_type { - unit - } else { - return; - }; - - match (unit, time_unit) { - (ParquetTimeUnit::Milliseconds, TimeUnit::Second) => array - .values_mut_slice() - .iter_mut() - .for_each(|x| *x /= 1_000), - (ParquetTimeUnit::Microseconds, TimeUnit::Second) => array - .values_mut_slice() - .iter_mut() - .for_each(|x| *x /= 1_000_000), - (ParquetTimeUnit::Nanoseconds, TimeUnit::Second) => array - .values_mut_slice() - .iter_mut() - .for_each(|x| *x /= 1_000_000_000), - - (ParquetTimeUnit::Milliseconds, TimeUnit::Millisecond) => {}, - (ParquetTimeUnit::Microseconds, TimeUnit::Millisecond) => array - .values_mut_slice() - .iter_mut() - .for_each(|x| *x /= 1_000), - (ParquetTimeUnit::Nanoseconds, TimeUnit::Millisecond) => array - .values_mut_slice() - .iter_mut() - .for_each(|x| *x /= 1_000_000), - - (ParquetTimeUnit::Milliseconds, TimeUnit::Microsecond) => array - .values_mut_slice() - .iter_mut() - .for_each(|x| *x *= 1_000), - (ParquetTimeUnit::Microseconds, TimeUnit::Microsecond) => {}, - (ParquetTimeUnit::Nanoseconds, TimeUnit::Microsecond) => array - .values_mut_slice() - .iter_mut() - .for_each(|x| *x /= 1_000), - - (ParquetTimeUnit::Milliseconds, TimeUnit::Nanosecond) => array - .values_mut_slice() - .iter_mut() - .for_each(|x| *x *= 1_000_000), - (ParquetTimeUnit::Microseconds, TimeUnit::Nanosecond) => array - .values_mut_slice() - .iter_mut() - .for_each(|x| *x /= 1_000), - (ParquetTimeUnit::Nanoseconds, TimeUnit::Nanosecond) => {}, - } -} - -#[inline] -fn deserialize_int64>>( - iter: I, - primitive_type: &PrimitiveType, - data_type: ArrowDataType, -) -> Box { - use ArrowDataType::*; - match data_type.to_logical_type() { - UInt64 => Box::new( - PrimitiveArray::::from_trusted_len_iter(iter.map(|x| x.map(|x| x as u64))) - .to(data_type), - ) as _, - Decimal(_, _) => Box::new( - PrimitiveArray::::from_trusted_len_iter(iter.map(|x| x.map(|x| x as i128))) - .to(data_type), - ) as _, - Decimal256(_, _) => Box::new( - PrimitiveArray::::from_trusted_len_iter( - iter.map(|x| x.map(|x| i256(I256::new(x.into())))), - ) - .to(data_type), - ) as _, - Timestamp(time_unit, _) => { - let mut array = - MutablePrimitiveArray::::from_trusted_len_iter(iter).to(data_type.clone()); - - timestamp(&mut array, *time_unit, primitive_type.logical_type); - - let array: PrimitiveArray = array.into(); - - Box::new(array) - }, - _ => Box::new(PrimitiveArray::::from_trusted_len_iter(iter).to(data_type)), - } -} - -#[inline] -fn deserialize_int96>>( - iter: I, - data_type: ArrowDataType, -) -> Box { - Box::new( - PrimitiveArray::::from_trusted_len_iter(iter.map(|x| x.map(int96_to_i64_ns))) - .to(data_type), - ) -} - -#[inline] -fn deserialize_id_s>>( - iter: I, - data_type: ArrowDataType, -) -> Box { - Box::new(PrimitiveArray::::from_trusted_len_iter(iter).to(data_type)) -} - -pub fn deserialize_i32( - indexes: &[PageIndex], - data_type: ArrowDataType, -) -> ColumnPageStatistics { - ColumnPageStatistics { - min: deserialize_int32(indexes.iter().map(|index| index.min), data_type.clone()), - max: deserialize_int32(indexes.iter().map(|index| index.max), data_type), - null_count: PrimitiveArray::from_trusted_len_iter( - indexes - .iter() - .map(|index| index.null_count.map(|x| x as u64)), - ), - } -} - -pub fn deserialize_i64( - indexes: &[PageIndex], - primitive_type: &PrimitiveType, - data_type: ArrowDataType, -) -> ColumnPageStatistics { - ColumnPageStatistics { - min: deserialize_int64( - indexes.iter().map(|index| index.min), - primitive_type, - data_type.clone(), - ), - max: deserialize_int64( - 
indexes.iter().map(|index| index.max), - primitive_type, - data_type, - ), - null_count: PrimitiveArray::from_trusted_len_iter( - indexes - .iter() - .map(|index| index.null_count.map(|x| x as u64)), - ), - } -} - -pub fn deserialize_i96( - indexes: &[PageIndex<[u32; 3]>], - data_type: ArrowDataType, -) -> ColumnPageStatistics { - ColumnPageStatistics { - min: deserialize_int96(indexes.iter().map(|index| index.min), data_type.clone()), - max: deserialize_int96(indexes.iter().map(|index| index.max), data_type), - null_count: PrimitiveArray::from_trusted_len_iter( - indexes - .iter() - .map(|index| index.null_count.map(|x| x as u64)), - ), - } -} - -pub fn deserialize_id( - indexes: &[PageIndex], - data_type: ArrowDataType, -) -> ColumnPageStatistics { - ColumnPageStatistics { - min: deserialize_id_s(indexes.iter().map(|index| index.min), data_type.clone()), - max: deserialize_id_s(indexes.iter().map(|index| index.max), data_type), - null_count: PrimitiveArray::from_trusted_len_iter( - indexes - .iter() - .map(|index| index.null_count.map(|x| x as u64)), - ), - } -} diff --git a/crates/polars-parquet/src/arrow/read/mod.rs b/crates/polars-parquet/src/arrow/read/mod.rs index fff6987f4f1a..8af4fb3f67bb 100644 --- a/crates/polars-parquet/src/arrow/read/mod.rs +++ b/crates/polars-parquet/src/arrow/read/mod.rs @@ -2,25 +2,19 @@ #![allow(clippy::type_complexity)] mod deserialize; -mod file; -pub mod indexes; -mod row_group; pub mod schema; pub mod statistics; use std::io::{Read, Seek}; -use arrow::array::Array; use arrow::types::{i256, NativeType}; pub use deserialize::{ column_iter_to_arrays, create_list, create_map, get_page_iterator, init_nested, n_columns, Filter, InitNested, NestedState, }; -pub use file::{FileReader, RowGroupReader}; #[cfg(feature = "async")] use futures::{AsyncRead, AsyncSeek}; use polars_error::PolarsResult; -pub use row_group::*; pub use schema::{infer_schema, FileMetaData}; use crate::parquet::error::ParquetResult; @@ -30,12 +24,11 @@ pub use crate::parquet::read::{get_page_stream, read_metadata_async as _read_met pub use crate::parquet::{ error::ParquetError, fallible_streaming_iterator, - metadata::{ColumnChunkMetaData, ColumnDescriptor, RowGroupMetaData}, + metadata::{ColumnChunkMetadata, ColumnDescriptor, RowGroupMetaData}, page::{CompressedDataPage, DataPageHeader, Page}, read::{ - decompress, get_column_iterator, read_columns_indexes as _read_columns_indexes, - read_metadata as _read_metadata, read_pages_locations, BasicDecompressor, - MutStreamingIterator, PageFilter, PageReader, ReadColumnIterator, State, + decompress, get_column_iterator, read_metadata as _read_metadata, BasicDecompressor, + MutStreamingIterator, PageReader, ReadColumnIterator, State, }, schema::types::{ GroupLogicalType, ParquetType, PhysicalType, PrimitiveConvertedType, PrimitiveLogicalType, @@ -45,8 +38,20 @@ pub use crate::parquet::{ FallibleStreamingIterator, }; -/// Type def for a sharable, boxed dyn [`Iterator`] of arrays -pub type ArrayIter<'a> = Box>> + Send + Sync + 'a>; +/// Returns all [`ColumnChunkMetadata`] associated to `field_name`. +/// For non-nested parquet types, this returns a single column +pub fn get_field_pages<'a, T>( + columns: &'a [ColumnChunkMetadata], + items: &'a [T], + field_name: &str, +) -> Vec<&'a T> { + columns + .iter() + .zip(items) + .filter(|(metadata, _)| metadata.descriptor().path_in_schema[0].as_str() == field_name) + .map(|(_, item)| item) + .collect() +} /// Reads parquets' metadata synchronously. 
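Editor's note: the newly added `get_field_pages` pairs every column chunk with its per-column item and keeps only the items whose column path starts with the requested field name, so a nested field that maps to several parquet columns returns several items. A self-contained sketch of that zip-and-filter pattern, using dotted strings as a stand-in for `path_in_schema` (illustrative types, not the real metadata structs):

/// `paths[i]` is the dotted schema path of column `i`; `items[i]` is the per-column payload.
/// Returns references to all items whose path root equals `field_name`.
fn get_field_pages<'a, T>(paths: &'a [&'a str], items: &'a [T], field_name: &str) -> Vec<&'a T> {
    paths
        .iter()
        .zip(items)
        .filter(|(path, _)| path.split('.').next() == Some(field_name))
        .map(|(_, item)| item)
        .collect()
}

fn main() {
    let paths = ["a", "s.x", "s.y"];
    let locations = [vec![0usize], vec![1], vec![2]];
    // The struct field "s" owns two parquet columns, so two items come back.
    assert_eq!(get_field_pages(&paths, &locations, "s").len(), 2);
}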
pub fn read_metadata(reader: &mut R) -> PolarsResult { diff --git a/crates/polars-parquet/src/arrow/read/schema/convert.rs b/crates/polars-parquet/src/arrow/read/schema/convert.rs index 2089e261188f..e79139109845 100644 --- a/crates/polars-parquet/src/arrow/read/schema/convert.rs +++ b/crates/polars-parquet/src/arrow/read/schema/convert.rs @@ -1,5 +1,6 @@ //! This module has entry points, [`parquet_to_arrow_schema`] and the more configurable [`parquet_to_arrow_schema_with_options`]. -use arrow::datatypes::{ArrowDataType, Field, IntervalUnit, TimeUnit}; +use arrow::datatypes::{ArrowDataType, ArrowSchema, Field, IntervalUnit, TimeUnit}; +use polars_utils::pl_str::PlSmallStr; use crate::arrow::read::schema::SchemaInferenceOptions; use crate::parquet::schema::types::{ @@ -10,7 +11,7 @@ use crate::parquet::schema::Repetition; /// Converts [`ParquetType`]s to a [`Field`], ignoring parquet fields that do not contain /// any physical column. -pub fn parquet_to_arrow_schema(fields: &[ParquetType]) -> Vec { +pub fn parquet_to_arrow_schema(fields: &[ParquetType]) -> ArrowSchema { parquet_to_arrow_schema_with_options(fields, &None) } @@ -18,11 +19,12 @@ pub fn parquet_to_arrow_schema(fields: &[ParquetType]) -> Vec { pub fn parquet_to_arrow_schema_with_options( fields: &[ParquetType], options: &Option, -) -> Vec { +) -> ArrowSchema { fields .iter() .filter_map(|f| to_field(f, options.as_ref().unwrap_or(&Default::default()))) - .collect::>() + .map(|x| (x.name.clone(), x)) + .collect() } fn from_int32( @@ -91,7 +93,7 @@ fn from_int64( let timezone = if is_adjusted_to_utc { // https://github.com/apache/parquet-format/blob/master/LogicalTypes.md // A TIMESTAMP with isAdjustedToUTC=true is defined as [...] elapsed since the Unix epoch - Some("+00:00".to_string()) + Some(PlSmallStr::from_static("+00:00")) } else { // PARQUET: // https://github.com/apache/parquet-format/blob/master/LogicalTypes.md @@ -222,7 +224,7 @@ fn to_primitive_type( if primitive_type.field_info.repetition == Repetition::Repeated { ArrowDataType::LargeList(Box::new(Field::new( - &primitive_type.field_info.name, + primitive_type.field_info.name.clone(), base_type, is_nullable(&primitive_type.field_info), ))) @@ -285,7 +287,7 @@ fn to_group_type( debug_assert!(!fields.is_empty()); if field_info.repetition == Repetition::Repeated { Some(ArrowDataType::LargeList(Box::new(Field::new( - &field_info.name, + field_info.name.clone(), to_struct(fields, options)?, is_nullable(field_info), )))) @@ -308,8 +310,8 @@ pub(crate) fn is_nullable(field_info: &FieldInfo) -> bool { /// i.e. if it is a column-less group type. fn to_field(type_: &ParquetType, options: &SchemaInferenceOptions) -> Option { Some(Field::new( - &type_.get_field_info().name, - to_data_type(type_, options)?, + type_.get_field_info().name.clone(), + to_dtype(type_, options)?, is_nullable(type_.get_field_info()), )) } @@ -328,13 +330,17 @@ fn to_list( let item_type = match item { ParquetType::PrimitiveType(primitive) => Some(to_primitive_type_inner(primitive, options)), ParquetType::GroupType { fields, .. 
} => { - if fields.len() == 1 - && item.name() != "array" - && item.name() != format!("{parent_name}_tuple") - { + if fields.len() == 1 && item.name() != "array" && { + // item.name() != format!("{parent_name}_tuple") + let cmp = [parent_name, "_tuple"]; + let len_1 = parent_name.len(); + let len = len_1 + "_tuple".len(); + + item.name().len() != len || [&item.name()[..len_1], &item.name()[len_1..]] != cmp + } { // extract the repetition field let nested_item = fields.first().unwrap(); - to_data_type(nested_item, options) + to_dtype(nested_item, options) } else { to_struct(fields, options) } @@ -348,15 +354,15 @@ fn to_list( let (list_item_name, item_is_optional) = match item { ParquetType::GroupType { field_info, fields, .. - } if field_info.name == "list" && fields.len() == 1 => { + } if field_info.name.as_str() == "list" && fields.len() == 1 => { let field = fields.first().unwrap(); ( - &field.get_field_info().name, + field.get_field_info().name.clone(), field.get_field_info().repetition == Repetition::Optional, ) }, _ => ( - &item.get_field_info().name, + item.get_field_info().name.clone(), item.get_field_info().repetition == Repetition::Optional, ), }; @@ -377,7 +383,7 @@ fn to_list( /// /// If this schema is a group type and none of its children is reserved in the /// conversion, the result is Ok(None). -pub(crate) fn to_data_type( +pub(crate) fn to_dtype( type_: &ParquetType, options: &SchemaInferenceOptions, ) -> Option { @@ -397,7 +403,7 @@ pub(crate) fn to_data_type( logical_type, converted_type, fields, - &field_info.name, + field_info.name.as_str(), options, ) } @@ -430,21 +436,22 @@ mod tests { } "; let expected = &[ - Field::new("boolean", ArrowDataType::Boolean, false), - Field::new("int8", ArrowDataType::Int8, false), - Field::new("int16", ArrowDataType::Int16, false), - Field::new("uint8", ArrowDataType::UInt8, false), - Field::new("uint16", ArrowDataType::UInt16, false), - Field::new("int32", ArrowDataType::Int32, false), - Field::new("int64", ArrowDataType::Int64, false), - Field::new("double", ArrowDataType::Float64, true), - Field::new("float", ArrowDataType::Float32, true), - Field::new("string", ArrowDataType::Utf8View, true), - Field::new("string_2", ArrowDataType::Utf8View, true), + Field::new("boolean".into(), ArrowDataType::Boolean, false), + Field::new("int8".into(), ArrowDataType::Int8, false), + Field::new("int16".into(), ArrowDataType::Int16, false), + Field::new("uint8".into(), ArrowDataType::UInt8, false), + Field::new("uint16".into(), ArrowDataType::UInt16, false), + Field::new("int32".into(), ArrowDataType::Int32, false), + Field::new("int64".into(), ArrowDataType::Int64, false), + Field::new("double".into(), ArrowDataType::Float64, true), + Field::new("float".into(), ArrowDataType::Float32, true), + Field::new("string".into(), ArrowDataType::Utf8View, true), + Field::new("string_2".into(), ArrowDataType::Utf8View, true), ]; let parquet_schema = SchemaDescriptor::try_from_message(message)?; let fields = parquet_to_arrow_schema(parquet_schema.fields()); + let fields = fields.iter_values().cloned().collect::>(); assert_eq!(fields, expected); Ok(()) @@ -459,12 +466,17 @@ mod tests { } "; let expected = vec![ - Field::new("binary", ArrowDataType::BinaryView, false), - Field::new("fixed_binary", ArrowDataType::FixedSizeBinary(20), false), + Field::new("binary".into(), ArrowDataType::BinaryView, false), + Field::new( + "fixed_binary".into(), + ArrowDataType::FixedSizeBinary(20), + false, + ), ]; let parquet_schema = SchemaDescriptor::try_from_message(message)?; 
let fields = parquet_to_arrow_schema(parquet_schema.fields()); + let fields = fields.iter_values().cloned().collect::>(); assert_eq!(fields, expected); Ok(()) @@ -479,12 +491,13 @@ mod tests { } "; let expected = &[ - Field::new("boolean", ArrowDataType::Boolean, false), - Field::new("int8", ArrowDataType::Int8, false), + Field::new("boolean".into(), ArrowDataType::Boolean, false), + Field::new("int8".into(), ArrowDataType::Int8, false), ]; let parquet_schema = SchemaDescriptor::try_from_message(message)?; let fields = parquet_to_arrow_schema(parquet_schema.fields()); + let fields = fields.iter_values().cloned().collect::>(); assert_eq!(fields, expected); Ok(()) @@ -554,9 +567,9 @@ mod tests { // } { arrow_fields.push(Field::new( - "my_list", + "my_list".into(), ArrowDataType::LargeList(Box::new(Field::new( - "element", + "element".into(), ArrowDataType::Utf8, true, ))), @@ -572,9 +585,9 @@ mod tests { // } { arrow_fields.push(Field::new( - "my_list", + "my_list".into(), ArrowDataType::LargeList(Box::new(Field::new( - "element", + "element".into(), ArrowDataType::Utf8, false, ))), @@ -596,13 +609,17 @@ mod tests { // } { let arrow_inner_list = ArrowDataType::LargeList(Box::new(Field::new( - "element", + "element".into(), ArrowDataType::Int32, false, ))); arrow_fields.push(Field::new( - "array_of_arrays", - ArrowDataType::LargeList(Box::new(Field::new("element", arrow_inner_list, false))), + "array_of_arrays".into(), + ArrowDataType::LargeList(Box::new(Field::new( + PlSmallStr::from_static("element"), + arrow_inner_list, + false, + ))), true, )); } @@ -615,9 +632,9 @@ mod tests { // } { arrow_fields.push(Field::new( - "my_list", + "my_list".into(), ArrowDataType::LargeList(Box::new(Field::new( - "element", + "element".into(), ArrowDataType::Utf8, false, ))), @@ -631,9 +648,9 @@ mod tests { // } { arrow_fields.push(Field::new( - "my_list", + "my_list".into(), ArrowDataType::LargeList(Box::new(Field::new( - "element", + "element".into(), ArrowDataType::Int32, false, ))), @@ -650,12 +667,16 @@ mod tests { // } { let arrow_struct = ArrowDataType::Struct(vec![ - Field::new("str", ArrowDataType::Utf8, false), - Field::new("num", ArrowDataType::Int32, false), + Field::new("str".into(), ArrowDataType::Utf8, false), + Field::new("num".into(), ArrowDataType::Int32, false), ]); arrow_fields.push(Field::new( - "my_list", - ArrowDataType::LargeList(Box::new(Field::new("element", arrow_struct, false))), + "my_list".into(), + ArrowDataType::LargeList(Box::new(Field::new( + "element".into(), + arrow_struct, + false, + ))), true, )); } @@ -669,10 +690,10 @@ mod tests { // Special case: group is named array { let arrow_struct = - ArrowDataType::Struct(vec![Field::new("str", ArrowDataType::Utf8, false)]); + ArrowDataType::Struct(vec![Field::new("str".into(), ArrowDataType::Utf8, false)]); arrow_fields.push(Field::new( - "my_list", - ArrowDataType::LargeList(Box::new(Field::new("array", arrow_struct, false))), + "my_list".into(), + ArrowDataType::LargeList(Box::new(Field::new("array".into(), arrow_struct, false))), true, )); } @@ -686,11 +707,11 @@ mod tests { // Special case: group named ends in _tuple { let arrow_struct = - ArrowDataType::Struct(vec![Field::new("str", ArrowDataType::Utf8, false)]); + ArrowDataType::Struct(vec![Field::new("str".into(), ArrowDataType::Utf8, false)]); arrow_fields.push(Field::new( - "my_list", + "my_list".into(), ArrowDataType::LargeList(Box::new(Field::new( - "my_list_tuple", + "my_list_tuple".into(), arrow_struct, false, ))), @@ -702,14 +723,19 @@ mod tests { // repeated 
value_type name { arrow_fields.push(Field::new( - "name", - ArrowDataType::LargeList(Box::new(Field::new("name", ArrowDataType::Int32, false))), + "name".into(), + ArrowDataType::LargeList(Box::new(Field::new( + "name".into(), + ArrowDataType::Int32, + false, + ))), false, )); } let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; let fields = parquet_to_arrow_schema(parquet_schema.fields()); + let fields = fields.iter_values().cloned().collect::>(); assert_eq!(arrow_fields, fields); Ok(()) @@ -732,17 +758,17 @@ mod tests { { let struct_fields = vec![ - Field::new("event_name", ArrowDataType::Utf8View, false), + Field::new("event_name".into(), ArrowDataType::Utf8View, false), Field::new( - "event_time", + "event_time".into(), ArrowDataType::Timestamp(TimeUnit::Millisecond, Some("+00:00".into())), false, ), ]; arrow_fields.push(Field::new( - "events", + "events".into(), ArrowDataType::LargeList(Box::new(Field::new( - "array", + "array".into(), ArrowDataType::Struct(struct_fields), false, ))), @@ -752,6 +778,7 @@ mod tests { let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; let fields = parquet_to_arrow_schema(parquet_schema.fields()); + let fields = fields.iter_values().cloned().collect::>(); assert_eq!(arrow_fields, fields); Ok(()) @@ -789,9 +816,9 @@ mod tests { // } { arrow_fields.push(Field::new( - "my_list1", + "my_list1".into(), ArrowDataType::LargeList(Box::new(Field::new( - "element", + "element".into(), ArrowDataType::Utf8View, true, ))), @@ -807,9 +834,9 @@ mod tests { // } { arrow_fields.push(Field::new( - "my_list2", + "my_list2".into(), ArrowDataType::LargeList(Box::new(Field::new( - "element", + "element".into(), ArrowDataType::Utf8View, false, ))), @@ -825,9 +852,9 @@ mod tests { // } { arrow_fields.push(Field::new( - "my_list3", + "my_list3".into(), ArrowDataType::LargeList(Box::new(Field::new( - "element", + "element".into(), ArrowDataType::Utf8View, false, ))), @@ -837,6 +864,7 @@ mod tests { let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; let fields = parquet_to_arrow_schema(parquet_schema.fields()); + let fields = fields.iter_values().cloned().collect::>(); assert_eq!(arrow_fields, fields); Ok(()) @@ -847,13 +875,14 @@ mod tests { let mut arrow_fields = Vec::new(); { let group1_fields = vec![ - Field::new("leaf1", ArrowDataType::Boolean, false), - Field::new("leaf2", ArrowDataType::Int32, false), + Field::new("leaf1".into(), ArrowDataType::Boolean, false), + Field::new("leaf2".into(), ArrowDataType::Int32, false), ]; - let group1_struct = Field::new("group1", ArrowDataType::Struct(group1_fields), false); + let group1_struct = + Field::new("group1".into(), ArrowDataType::Struct(group1_fields), false); arrow_fields.push(group1_struct); - let leaf3_field = Field::new("leaf3", ArrowDataType::Int64, false); + let leaf3_field = Field::new("leaf3".into(), ArrowDataType::Int64, false); arrow_fields.push(leaf3_field); } @@ -869,6 +898,7 @@ mod tests { let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; let fields = parquet_to_arrow_schema(parquet_schema.fields()); + let fields = fields.iter_values().cloned().collect::>(); assert_eq!(arrow_fields, fields); Ok(()) @@ -879,24 +909,28 @@ mod tests { fn test_repeated_nested_schema() -> PolarsResult<()> { let mut arrow_fields = Vec::new(); { - arrow_fields.push(Field::new("leaf1", ArrowDataType::Int32, true)); + arrow_fields.push(Field::new("leaf1".into(), ArrowDataType::Int32, true)); let inner_group_list = Field::new( - "innerGroup", + 
"innerGroup".into(), ArrowDataType::LargeList(Box::new(Field::new( - "innerGroup", - ArrowDataType::Struct(vec![Field::new("leaf3", ArrowDataType::Int32, true)]), + "innerGroup".into(), + ArrowDataType::Struct(vec![Field::new( + "leaf3".into(), + ArrowDataType::Int32, + true, + )]), false, ))), false, ); let outer_group_list = Field::new( - "outerGroup", + "outerGroup".into(), ArrowDataType::LargeList(Box::new(Field::new( - "outerGroup", + "outerGroup".into(), ArrowDataType::Struct(vec![ - Field::new("leaf2", ArrowDataType::Int32, true), + Field::new("leaf2".into(), ArrowDataType::Int32, true), inner_group_list, ]), false, @@ -920,6 +954,7 @@ mod tests { let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; let fields = parquet_to_arrow_schema(parquet_schema.fields()); + let fields = fields.iter_values().cloned().collect::>(); assert_eq!(arrow_fields, fields); Ok(()) @@ -951,60 +986,61 @@ mod tests { } "; let arrow_fields = vec![ - Field::new("boolean", ArrowDataType::Boolean, false), - Field::new("int8", ArrowDataType::Int8, false), - Field::new("uint8", ArrowDataType::UInt8, false), - Field::new("int16", ArrowDataType::Int16, false), - Field::new("uint16", ArrowDataType::UInt16, false), - Field::new("int32", ArrowDataType::Int32, false), - Field::new("int64", ArrowDataType::Int64, false), - Field::new("double", ArrowDataType::Float64, true), - Field::new("float", ArrowDataType::Float32, true), - Field::new("string", ArrowDataType::Utf8, true), + Field::new("boolean".into(), ArrowDataType::Boolean, false), + Field::new("int8".into(), ArrowDataType::Int8, false), + Field::new("uint8".into(), ArrowDataType::UInt8, false), + Field::new("int16".into(), ArrowDataType::Int16, false), + Field::new("uint16".into(), ArrowDataType::UInt16, false), + Field::new("int32".into(), ArrowDataType::Int32, false), + Field::new("int64".into(), ArrowDataType::Int64, false), + Field::new("double".into(), ArrowDataType::Float64, true), + Field::new("float".into(), ArrowDataType::Float32, true), + Field::new("string".into(), ArrowDataType::Utf8, true), Field::new( - "bools", + "bools".into(), ArrowDataType::LargeList(Box::new(Field::new( - "bools", + "bools".into(), ArrowDataType::Boolean, false, ))), false, ), - Field::new("date", ArrowDataType::Date32, true), + Field::new("date".into(), ArrowDataType::Date32, true), Field::new( - "time_milli", + "time_milli".into(), ArrowDataType::Time32(TimeUnit::Millisecond), true, ), Field::new( - "time_micro", + "time_micro".into(), ArrowDataType::Time64(TimeUnit::Microsecond), true, ), Field::new( - "time_nano", + "time_nano".into(), ArrowDataType::Time64(TimeUnit::Nanosecond), true, ), Field::new( - "ts_milli", + "ts_milli".into(), ArrowDataType::Timestamp(TimeUnit::Millisecond, None), true, ), Field::new( - "ts_micro", + "ts_micro".into(), ArrowDataType::Timestamp(TimeUnit::Microsecond, None), false, ), Field::new( - "ts_nano", - ArrowDataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".to_string())), + "ts_nano".into(), + ArrowDataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())), false, ), ]; let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; let fields = parquet_to_arrow_schema(parquet_schema.fields()); + let fields = fields.iter_values().cloned().collect::>(); assert_eq!(arrow_fields, fields); Ok(()) @@ -1051,62 +1087,62 @@ mod tests { "; let arrow_fields = vec![ - Field::new("boolean", ArrowDataType::Boolean, false), - Field::new("int8", ArrowDataType::Int8, false), - Field::new("int16", ArrowDataType::Int16, 
false), - Field::new("int32", ArrowDataType::Int32, false), - Field::new("int64", ArrowDataType::Int64, false), - Field::new("double", ArrowDataType::Float64, true), - Field::new("float", ArrowDataType::Float32, true), - Field::new("string", ArrowDataType::Utf8View, true), + Field::new("boolean".into(), ArrowDataType::Boolean, false), + Field::new("int8".into(), ArrowDataType::Int8, false), + Field::new("int16".into(), ArrowDataType::Int16, false), + Field::new("int32".into(), ArrowDataType::Int32, false), + Field::new("int64".into(), ArrowDataType::Int64, false), + Field::new("double".into(), ArrowDataType::Float64, true), + Field::new("float".into(), ArrowDataType::Float32, true), + Field::new("string".into(), ArrowDataType::Utf8View, true), Field::new( - "bools", + "bools".into(), ArrowDataType::LargeList(Box::new(Field::new( - "element", + "element".into(), ArrowDataType::Boolean, true, ))), true, ), Field::new( - "bools_non_null", + "bools_non_null".into(), ArrowDataType::LargeList(Box::new(Field::new( - "element", + "element".into(), ArrowDataType::Boolean, false, ))), false, ), - Field::new("date", ArrowDataType::Date32, true), + Field::new("date".into(), ArrowDataType::Date32, true), Field::new( - "time_milli", + "time_milli".into(), ArrowDataType::Time32(TimeUnit::Millisecond), true, ), Field::new( - "time_micro", + "time_micro".into(), ArrowDataType::Time64(TimeUnit::Microsecond), true, ), Field::new( - "ts_milli", + "ts_milli".into(), ArrowDataType::Timestamp(TimeUnit::Millisecond, None), true, ), Field::new( - "ts_micro", + "ts_micro".into(), ArrowDataType::Timestamp(TimeUnit::Microsecond, None), false, ), Field::new( - "struct", + "struct".into(), ArrowDataType::Struct(vec![ - Field::new("bools", ArrowDataType::Boolean, false), - Field::new("uint32", ArrowDataType::UInt32, false), + Field::new("bools".into(), ArrowDataType::Boolean, false), + Field::new("uint32".into(), ArrowDataType::UInt32, false), Field::new( - "int32", + "int32".into(), ArrowDataType::LargeList(Box::new(Field::new( - "element", + "element".into(), ArrowDataType::Int32, true, ))), @@ -1115,11 +1151,12 @@ mod tests { ]), false, ), - Field::new("dictionary_strings", ArrowDataType::Utf8View, false), + Field::new("dictionary_strings".into(), ArrowDataType::Utf8View, false), ]; let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; let fields = parquet_to_arrow_schema(parquet_schema.fields()); + let fields = fields.iter_values().cloned().collect::>(); assert_eq!(arrow_fields, fields); Ok(()) @@ -1148,20 +1185,20 @@ mod tests { "; let coerced_to = ArrowDataType::Timestamp(tu, None); let arrow_fields = vec![ - Field::new("int96_field", coerced_to.clone(), false), + Field::new("int96_field".into(), coerced_to.clone(), false), Field::new( - "int96_list", + "int96_list".into(), ArrowDataType::LargeList(Box::new(Field::new( - "element", + "element".into(), coerced_to.clone(), true, ))), true, ), Field::new( - "int96_struct", + "int96_struct".into(), ArrowDataType::Struct(vec![Field::new( - "int96_field", + "int96_field".into(), coerced_to.clone(), false, )]), @@ -1176,6 +1213,7 @@ mod tests { int96_coerce_to_timeunit: tu, }), ); + let fields = fields.iter_values().cloned().collect::>(); assert_eq!(arrow_fields, fields); } Ok(()) diff --git a/crates/polars-parquet/src/arrow/read/schema/metadata.rs b/crates/polars-parquet/src/arrow/read/schema/metadata.rs index 5b3dd20725cb..915936c81296 100644 --- a/crates/polars-parquet/src/arrow/read/schema/metadata.rs +++ 
b/crates/polars-parquet/src/arrow/read/schema/metadata.rs @@ -3,6 +3,7 @@ use arrow::io::ipc::read::deserialize_schema; use base64::engine::general_purpose; use base64::Engine as _; use polars_error::{polars_bail, PolarsResult}; +use polars_utils::pl_str::PlSmallStr; use super::super::super::ARROW_SCHEMA_META_KEY; pub use crate::parquet::metadata::KeyValue; @@ -17,44 +18,40 @@ pub fn read_schema_from_metadata(metadata: &mut Metadata) -> PolarsResult

, { + let is_optional = options.is_optional(); + if is_optional { // append the non-null values let validity = array.validity(); @@ -33,10 +35,10 @@ where let null_count = validity.unset_bits(); if null_count > 0 { - let values = array.values().as_slice(); let mut iter = validity.iter(); + let values = array.values().as_slice(); - buffer.reserve(std::mem::size_of::

() * (array.len() - null_count)); + buffer.reserve(std::mem::size_of::() * (array.len() - null_count)); let mut offset = 0; let mut remaining_valid = array.len() - null_count; @@ -72,7 +74,7 @@ where pub(crate) fn encode_delta( array: &PrimitiveArray, - is_optional: bool, + options: EncodeNullability, mut buffer: Vec, ) -> Vec where @@ -81,6 +83,8 @@ where T: num_traits::AsPrimitive

, P: num_traits::AsPrimitive, { + let is_optional = options.is_optional(); + if is_optional { // append the non-null values let iterator = array.non_null_values_iter().map(|x| { @@ -89,7 +93,7 @@ where integer }); let iterator = ExactSizedIter::new(iterator, array.len() - array.null_count()); - encode(iterator, &mut buffer) + encode(iterator, &mut buffer, 1) } else { // append all values let iterator = array.values().iter().map(|x| { @@ -97,7 +101,7 @@ where let integer: i64 = parquet_native.as_(); integer }); - encode(iterator, &mut buffer) + encode(iterator, &mut buffer, 1) } buffer } @@ -135,7 +139,7 @@ where .map(Page::Data) } -pub fn array_to_page, bool, Vec) -> Vec>( +pub fn array_to_page, EncodeNullability, Vec) -> Vec>( array: &PrimitiveArray, options: WriteOptions, type_: PrimitiveType, @@ -149,6 +153,7 @@ where T: num_traits::AsPrimitive

, { let is_optional = is_nullable(&type_.field_info); + let encode_options = EncodeNullability::new(is_optional); let validity = array.validity(); @@ -163,7 +168,7 @@ where let definition_levels_byte_length = buffer.len(); - let buffer = encode(array, is_optional, buffer); + let buffer = encode(array, encode_options, buffer); let statistics = if options.has_statistics() { Some(build_statistics(array, type_.clone(), &options.statistics).serialize()) diff --git a/crates/polars-parquet/src/arrow/write/primitive/nested.rs b/crates/polars-parquet/src/arrow/write/primitive/nested.rs index 918afa6a4dc6..b5391263025e 100644 --- a/crates/polars-parquet/src/arrow/write/primitive/nested.rs +++ b/crates/polars-parquet/src/arrow/write/primitive/nested.rs @@ -10,6 +10,7 @@ use crate::parquet::encoding::Encoding; use crate::parquet::page::DataPage; use crate::parquet::schema::types::PrimitiveType; use crate::parquet::types::NativeType; +use crate::write::EncodeNullability; pub fn array_to_page( array: &PrimitiveArray, @@ -23,13 +24,14 @@ where T: num_traits::AsPrimitive, { let is_optional = is_nullable(&type_.field_info); + let encode_options = EncodeNullability::new(is_optional); let mut buffer = vec![]; let (repetition_levels_byte_length, definition_levels_byte_length) = nested::write_rep_and_def(options.version, nested, &mut buffer)?; - let buffer = encode_plain(array, is_optional, buffer); + let buffer = encode_plain(array, encode_options, buffer); let statistics = if options.has_statistics() { Some(build_statistics(array, type_.clone(), &options.statistics).serialize()) diff --git a/crates/polars-parquet/src/arrow/write/row_group.rs b/crates/polars-parquet/src/arrow/write/row_group.rs index 28928a2dab08..397b79ed46ee 100644 --- a/crates/polars-parquet/src/arrow/write/row_group.rs +++ b/crates/polars-parquet/src/arrow/write/row_group.rs @@ -82,7 +82,7 @@ impl + 'static, I: Iterator>, ) -> PolarsResult { - if encodings.len() != schema.fields.len() { + if encodings.len() != schema.len() { polars_bail!(InvalidOperation: "The number of encodings must equal the number of fields".to_string(), ) diff --git a/crates/polars-parquet/src/arrow/write/schema.rs b/crates/polars-parquet/src/arrow/write/schema.rs index 047291770180..1403a7f4eeec 100644 --- a/crates/polars-parquet/src/arrow/write/schema.rs +++ b/crates/polars-parquet/src/arrow/write/schema.rs @@ -3,6 +3,7 @@ use arrow::io::ipc::write::{default_ipc_fields, schema_to_bytes}; use base64::engine::general_purpose; use base64::Engine as _; use polars_error::{polars_bail, PolarsResult}; +use polars_utils::pl_str::PlSmallStr; use super::super::ARROW_SCHEMA_META_KEY; use crate::arrow::write::decimal_length_from_precision; @@ -16,15 +17,15 @@ use crate::parquet::schema::Repetition; fn convert_field(field: Field) -> Field { Field { name: field.name, - data_type: convert_data_type(field.data_type), + dtype: convert_dtype(field.dtype), is_nullable: field.is_nullable, metadata: field.metadata, } } -fn convert_data_type(data_type: ArrowDataType) -> ArrowDataType { +fn convert_dtype(dtype: ArrowDataType) -> ArrowDataType { use ArrowDataType as D; - match data_type { + match dtype { D::LargeList(field) => D::LargeList(Box::new(convert_field(*field))), D::Struct(mut fields) => { for field in &mut fields { @@ -34,13 +35,13 @@ fn convert_data_type(data_type: ArrowDataType) -> ArrowDataType { }, D::BinaryView => D::LargeBinary, D::Utf8View => D::LargeUtf8, - D::Dictionary(it, data_type, sorted) => { - let dtype = convert_data_type(*data_type); + D::Dictionary(it, dtype, 
sorted) => { + let dtype = convert_dtype(*dtype); D::Dictionary(it, Box::new(dtype), sorted) }, - D::Extension(name, data_type, metadata) => { - let data_type = convert_data_type(*data_type); - D::Extension(name, Box::new(data_type), metadata) + D::Extension(name, dtype, metadata) => { + let dtype = convert_dtype(*dtype); + D::Extension(name, Box::new(dtype), metadata) }, dt => dt, } @@ -48,16 +49,15 @@ fn convert_data_type(data_type: ArrowDataType) -> ArrowDataType { pub fn schema_to_metadata_key(schema: &ArrowSchema) -> KeyValue { // Convert schema until more arrow readers are aware of binview - let serialized_schema = if schema.fields.iter().any(|field| field.data_type.is_view()) { - let fields = schema - .fields - .iter() + let serialized_schema = if schema.iter_values().any(|field| field.dtype.is_view()) { + let schema = schema + .iter_values() .map(|field| convert_field(field.clone())) - .collect::>(); - let schema = ArrowSchema::from(fields); - schema_to_bytes(&schema, &default_ipc_fields(&schema.fields)) + .map(|x| (x.name.clone(), x)) + .collect(); + schema_to_bytes(&schema, &default_ipc_fields(schema.iter_values())) } else { - schema_to_bytes(schema, &default_ipc_fields(&schema.fields)) + schema_to_bytes(schema, &default_ipc_fields(schema.iter_values())) }; // manually prepending the length to the schema as arrow uses the legacy IPC format @@ -85,7 +85,7 @@ pub fn to_parquet_type(field: &Field) -> PolarsResult { Repetition::Required }; // create type from field - match field.data_type().to_logical_type() { + match field.dtype().to_logical_type() { ArrowDataType::Null => Ok(ParquetType::try_from_primitive( name, PhysicalType::Int32, @@ -303,7 +303,7 @@ pub fn to_parquet_type(field: &Field) -> PolarsResult { )) }, ArrowDataType::Dictionary(_, value, _) => { - let dict_field = Field::new(name.as_str(), value.as_ref().clone(), field.is_nullable); + let dict_field = Field::new(name.clone(), value.as_ref().clone(), field.is_nullable); to_parquet_type(&dict_field) }, ArrowDataType::FixedSizeBinary(size) => Ok(ParquetType::try_from_primitive( @@ -392,7 +392,7 @@ pub fn to_parquet_type(field: &Field) -> PolarsResult { | ArrowDataType::FixedSizeList(f, _) | ArrowDataType::LargeList(f) => { let mut f = f.clone(); - f.name = "element".to_string(); + f.name = PlSmallStr::from_static("element"); Ok(ParquetType::from_group( name, @@ -400,7 +400,7 @@ pub fn to_parquet_type(field: &Field) -> PolarsResult { Some(GroupConvertedType::List), Some(GroupLogicalType::List), vec![ParquetType::from_group( - "list".to_string(), + PlSmallStr::from_static("list"), Repetition::Repeated, None, None, @@ -416,7 +416,7 @@ pub fn to_parquet_type(field: &Field) -> PolarsResult { Some(GroupConvertedType::Map), Some(GroupLogicalType::Map), vec![ParquetType::from_group( - "map".to_string(), + PlSmallStr::from_static("map"), Repetition::Repeated, None, None, diff --git a/crates/polars-parquet/src/arrow/write/sink.rs b/crates/polars-parquet/src/arrow/write/sink.rs index ca93975a4c46..3c60ff9e9f70 100644 --- a/crates/polars-parquet/src/arrow/write/sink.rs +++ b/crates/polars-parquet/src/arrow/write/sink.rs @@ -1,13 +1,13 @@ use std::pin::Pin; use std::task::Poll; -use ahash::AHashMap; use arrow::array::Array; use arrow::datatypes::ArrowSchema; use arrow::record_batch::RecordBatchT; use futures::future::BoxFuture; use futures::{AsyncWrite, AsyncWriteExt, FutureExt, Sink, TryFutureExt}; use polars_error::{polars_bail, to_compute_err, PolarsError, PolarsResult}; +use polars_utils::aliases::PlHashMap; use 
super::file::add_arrow_schema; use super::{Encoding, SchemaDescriptor, WriteOptions}; @@ -26,7 +26,7 @@ pub struct FileSink<'a, W: AsyncWrite + Send + Unpin> { schema: ArrowSchema, parquet_schema: SchemaDescriptor, /// Key-value metadata that will be written to the file on close. - pub metadata: AHashMap>, + pub metadata: PlHashMap>, } impl<'a, W> FileSink<'a, W> @@ -45,7 +45,7 @@ where encodings: Vec>, options: WriteOptions, ) -> PolarsResult { - if encodings.len() != schema.fields.len() { + if encodings.len() != schema.len() { polars_bail!(InvalidOperation: "The number of encodings must equal the number of fields".to_string(), ) @@ -69,7 +69,7 @@ where schema, encodings, parquet_schema, - metadata: AHashMap::default(), + metadata: PlHashMap::default(), }) } @@ -120,7 +120,7 @@ where self: Pin<&mut Self>, item: RecordBatchT>, ) -> Result<(), Self::Error> { - if self.schema.fields.len() != item.arrays().len() { + if self.schema.len() != item.arrays().len() { polars_bail!(InvalidOperation: "The number of arrays in the chunk must equal the number of fields in the schema" ) diff --git a/crates/polars-parquet/src/arrow/write/utils.rs b/crates/polars-parquet/src/arrow/write/utils.rs index bbbe177af648..422732a211a4 100644 --- a/crates/polars-parquet/src/arrow/write/utils.rs +++ b/crates/polars-parquet/src/arrow/write/utils.rs @@ -92,7 +92,7 @@ pub fn build_plain_page( max_def_level: 0, max_rep_level: 0, }, - Some(num_rows), + num_rows, )) } @@ -134,16 +134,18 @@ impl> Iterator for ExactSizedIter { } } +impl> std::iter::ExactSizeIterator for ExactSizedIter {} + /// Returns the number of bits needed to bitpack `max` #[inline] pub fn get_bit_width(max: u64) -> u32 { 64 - max.leading_zeros() } -pub(super) fn invalid_encoding(encoding: Encoding, data_type: &ArrowDataType) -> PolarsError { +pub(super) fn invalid_encoding(encoding: Encoding, dtype: &ArrowDataType) -> PolarsError { polars_err!(InvalidOperation: "Datatype {:?} cannot be encoded by {:?} encoding", - data_type, + dtype, encoding ) } diff --git a/crates/polars-parquet/src/parquet/bloom_filter/read.rs b/crates/polars-parquet/src/parquet/bloom_filter/read.rs index 5ebbc29f1218..deda00b36272 100644 --- a/crates/polars-parquet/src/parquet/bloom_filter/read.rs +++ b/crates/polars-parquet/src/parquet/bloom_filter/read.rs @@ -7,14 +7,14 @@ use parquet_format_safe::{ }; use crate::parquet::error::ParquetResult; -use crate::parquet::metadata::ColumnChunkMetaData; +use crate::parquet::metadata::ColumnChunkMetadata; -/// Reads the bloom filter associated to [`ColumnChunkMetaData`] into `bitset`. +/// Reads the bloom filter associated to [`ColumnChunkMetadata`] into `bitset`. /// Results in an empty `bitset` if there is no associated bloom filter or the algorithm is not supported. /// # Error /// Errors if the column contains no metadata or the filter can't be read or deserialized. pub fn read( - column_metadata: &ColumnChunkMetaData, + column_metadata: &ColumnChunkMetadata, mut reader: &mut R, bitset: &mut Vec, ) -> ParquetResult<()> { diff --git a/crates/polars-parquet/src/parquet/compression.rs b/crates/polars-parquet/src/parquet/compression.rs index 7798af585b7b..41bfb5f557bf 100644 --- a/crates/polars-parquet/src/parquet/compression.rs +++ b/crates/polars-parquet/src/parquet/compression.rs @@ -26,6 +26,7 @@ fn inner_compress< /// Compresses data stored in slice `input_buf` and writes the compressed result /// to `output_buf`. 
+/// /// Note that you'll need to call `clear()` before reusing the same `output_buf` /// across different `compress` calls. #[allow(unused_variables)] diff --git a/crates/polars-parquet/src/parquet/encoding/bitpacked/decode.rs b/crates/polars-parquet/src/parquet/encoding/bitpacked/decode.rs index ce7fa301a7b4..6e37507d137f 100644 --- a/crates/polars-parquet/src/parquet/encoding/bitpacked/decode.rs +++ b/crates/polars-parquet/src/parquet/encoding/bitpacked/decode.rs @@ -1,5 +1,5 @@ use super::{Packed, Unpackable, Unpacked}; -use crate::parquet::error::ParquetError; +use crate::parquet::error::{ParquetError, ParquetResult}; /// An [`Iterator`] of [`Unpackable`] unpacked from a bitpacked slice of bytes. /// # Implementation @@ -9,34 +9,18 @@ pub struct Decoder<'a, T: Unpackable> { packed: std::slice::Chunks<'a, u8>, num_bits: usize, /// number of items - length: usize, + pub(crate) length: usize, _pd: std::marker::PhantomData, } -#[derive(Debug)] -pub struct DecoderIter { - buffer: Vec, - idx: usize, -} - -impl Iterator for DecoderIter { - type Item = T; - - fn next(&mut self) -> Option { - if self.idx >= self.buffer.len() { - return None; +impl<'a, T: Unpackable> Default for Decoder<'a, T> { + fn default() -> Self { + Self { + packed: [].chunks(1), + num_bits: 0, + length: 0, + _pd: std::marker::PhantomData, } - - let value = self.buffer[self.idx]; - self.idx += 1; - - Some(value) - } - - fn size_hint(&self) -> (usize, Option) { - let len = self.buffer.len() - self.idx; - - (len, Some(len)) } } @@ -57,18 +41,43 @@ impl<'a, T: Unpackable> Decoder<'a, T> { Self::try_new(packed, num_bits, length).unwrap() } - pub fn collect_into_iter(self) -> DecoderIter { - let mut buffer = Vec::new(); - self.collect_into(&mut buffer); - DecoderIter { buffer, idx: 0 } + /// Returns a [`Decoder`] with `T` encoded in `packed` with `num_bits`. + /// + /// `num_bits` is allowed to be `0`. + pub fn new_allow_zero(packed: &'a [u8], num_bits: usize, length: usize) -> Self { + Self::try_new_allow_zero(packed, num_bits, length).unwrap() } - pub fn num_bits(&self) -> usize { - self.num_bits + /// Returns a [`Decoder`] with `T` encoded in `packed` with `num_bits`. + /// + /// `num_bits` is allowed to be `0`. + pub fn try_new_allow_zero( + packed: &'a [u8], + num_bits: usize, + length: usize, + ) -> ParquetResult { + let block_size = std::mem::size_of::() * num_bits; + + if packed.len() * 8 < length * num_bits { + return Err(ParquetError::oos(format!( + "Unpacking {length} items with a number of bits {num_bits} requires at least {} bytes.", + length * num_bits / 8 + ))); + } + + debug_assert!(num_bits != 0 || packed.is_empty()); + let packed = packed.chunks(block_size.max(1)); + + Ok(Self { + length, + packed, + num_bits, + _pd: Default::default(), + }) } /// Returns a [`Decoder`] with `T` encoded in `packed` with `num_bits`. - pub fn try_new(packed: &'a [u8], num_bits: usize, length: usize) -> Result { + pub fn try_new(packed: &'a [u8], num_bits: usize, length: usize) -> ParquetResult { let block_size = std::mem::size_of::() * num_bits; if num_bits == 0 { @@ -91,11 +100,16 @@ impl<'a, T: Unpackable> Decoder<'a, T> { _pd: Default::default(), }) } + + pub fn num_bits(&self) -> usize { + self.num_bits + } } /// A iterator over the exact chunks in a [`Decoder`]. /// /// The remainder can be accessed using `remainder` or `next_inexact`. 
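The reworked `bitpacked::Decoder` above drops the old `DecoderIter` in favour of the chunked interface documented here. As a rough usage sketch, grounded only in the calls visible in this patch (`try_new`, `chunked`, `next_inexact`); `collect_bitpacked` is a hypothetical helper, not part of the change:

```rust
use crate::parquet::encoding::bitpacked;
use crate::parquet::error::ParquetResult;

/// Hypothetical helper: fully unpack a bitpacked buffer via the chunked API.
fn collect_bitpacked(packed: &[u8], num_bits: usize, len: usize) -> ParquetResult<Vec<u64>> {
    let mut decoder = bitpacked::Decoder::<u64>::try_new(packed, num_bits, len)?;
    let mut out = Vec::with_capacity(len);

    let mut chunked = decoder.chunked();
    // Full packs first...
    for chunk in &mut chunked {
        out.extend_from_slice(&chunk);
    }
    // ...then the trailing, partially filled pack (if any).
    if let Some((chunk, n)) = chunked.next_inexact() {
        out.extend_from_slice(&chunk[..n]);
    }
    Ok(out)
}
```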
+#[derive(Debug)] pub struct ChunkedDecoder<'a, 'b, T: Unpackable> { pub(crate) decoder: &'b mut Decoder<'a, T>, } diff --git a/crates/polars-parquet/src/parquet/encoding/bitpacked/mod.rs b/crates/polars-parquet/src/parquet/encoding/bitpacked/mod.rs index ef6c5313ba26..94f310d28f14 100644 --- a/crates/polars-parquet/src/parquet/encoding/bitpacked/mod.rs +++ b/crates/polars-parquet/src/parquet/encoding/bitpacked/mod.rs @@ -57,7 +57,7 @@ mod encode; mod pack; mod unpack; -pub use decode::{Decoder, DecoderIter}; +pub use decode::Decoder; pub use encode::{encode, encode_pack}; /// A byte slice (e.g. `[u8; 8]`) denoting types that represent complete packs. diff --git a/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/decoder.rs b/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/decoder.rs index ee21a5094718..261e84ce2e23 100644 --- a/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/decoder.rs +++ b/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/decoder.rs @@ -1,246 +1,812 @@ +//! This module implements the `DELTA_BINARY_PACKED` encoding. +//! +//! For performance reasons this is done without iterators. Instead, we have `gather_n` functions +//! and a `DeltaGatherer` trait. These allow efficient decoding and mapping of the decoded values. +//! +//! Full information on the delta encoding can be found on the Apache Parquet Format repository. +//! +//! +//! +//! Delta encoding compresses sequential integer values by encoding the first value and the +//! differences between consequentive values. This variant encodes the data into `Block`s and +//! `MiniBlock`s. +//! +//! - A `Block` contains a minimum delta, bitwidths and one or more miniblocks. +//! - A `MiniBlock` contains many deltas that are encoded in [`bitpacked`] encoding. +//! +//! The decoder keeps track of the last value and calculates a new value with the following +//! function. +//! +//! ```text +//! NextValue(Delta) = { +//! Value = Decoder.LastValue + Delta + Block.MinDelta +//! Decoder.LastValue = Value +//! return Value +//! } +//! ``` +//! +//! Note that all these additions need to be wrapping. + use super::super::{bitpacked, uleb128, zigzag_leb128}; -use crate::parquet::encoding::ceil8; +use super::lin_natural_sum; +use crate::parquet::encoding::bitpacked::{Unpackable, Unpacked}; use crate::parquet::error::{ParquetError, ParquetResult}; -/// An [`Iterator`] of [`i64`] +const MAX_BITWIDTH: u8 = 64; + +/// Decoder of parquets' `DELTA_BINARY_PACKED`. +#[derive(Debug)] +pub struct Decoder<'a> { + num_miniblocks_per_block: usize, + values_per_block: usize, + + values_remaining: usize, + + last_value: i64, + + values: &'a [u8], + + block: Block<'a>, +} + #[derive(Debug)] struct Block<'a> { - // this is the minimum delta that must be added to every value. min_delta: i64, - _num_mini_blocks: usize, - /// Number of values that each mini block has. - values_per_mini_block: usize, - bitwidths: std::slice::Iter<'a, u8>, - values: &'a [u8], - remaining: usize, // number of elements - current_index: usize, // invariant: < values_per_mini_block - // None represents a relative delta of zero, in which case there is no miniblock. - current_miniblock: Option>, - // number of bytes consumed. - consumed_bytes: usize, + + /// Bytes that give the `num_bits` for the [`bitpacked::Decoder`]. 
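The `NextValue` recurrence from the module docs above is easiest to see on a toy block. A self-contained illustration with made-up numbers (per the Parquet spec, the header preceding these blocks is `<block size in values> <number of miniblocks in a block> <total value count> <first value>`):

```rust
/// Toy reproduction of the `NextValue` recurrence (illustration only).
fn decode_block(first_value: i64, min_delta: i64, unpacked_deltas: &[u64]) -> Vec<i64> {
    let mut last = first_value;
    let mut out = vec![first_value];
    for &d in unpacked_deltas {
        // All additions are wrapping, as the docs above require.
        last = last.wrapping_add(d as i64).wrapping_add(min_delta);
        out.push(last);
    }
    out
}

// With min_delta = -2 and unpacked deltas [3, 0, 4]:
// 10, 10+3-2 = 11, 11+0-2 = 9, 9+4-2 = 11
assert_eq!(decode_block(10, -2, &[3, 0, 4]), vec![10, 11, 9, 11]);
```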
+ /// + /// Invariant: `bitwidth[i] <= MAX_BITWIDTH` for all `i` + bitwidths: &'a [u8], + values_remaining: usize, + miniblock: MiniBlock<'a>, } -impl<'a> Block<'a> { - pub fn try_new( - mut values: &'a [u8], - num_mini_blocks: usize, - values_per_mini_block: usize, - length: usize, - ) -> ParquetResult { - let length = std::cmp::min(length, num_mini_blocks * values_per_mini_block); - - let mut consumed_bytes = 0; - let (min_delta, consumed) = zigzag_leb128::decode(values); - consumed_bytes += consumed; - values = &values[consumed..]; - - if num_mini_blocks > values.len() { - return Err(ParquetError::oos( - "Block must contain at least num_mini_blocks bytes (the bitwidths)", - )); +#[derive(Debug)] +struct MiniBlock<'a> { + decoder: bitpacked::Decoder<'a, u64>, + buffered: ::Unpacked, + unpacked_start: usize, + unpacked_end: usize, +} + +struct SkipGatherer; +pub(crate) struct SumGatherer(pub(crate) usize); + +pub trait DeltaGatherer { + type Target: std::fmt::Debug; + + fn target_len(&self, target: &Self::Target) -> usize; + fn target_reserve(&self, target: &mut Self::Target, n: usize); + + /// Gather one element with value `v` into `target`. + fn gather_one(&mut self, target: &mut Self::Target, v: i64) -> ParquetResult<()>; + + /// Gather `num_repeats` elements into `target`. + /// + /// The first value is `v` and the `n`-th value is `v + (n-1)*delta`. + fn gather_constant( + &mut self, + target: &mut Self::Target, + v: i64, + delta: i64, + num_repeats: usize, + ) -> ParquetResult<()> { + for i in 0..num_repeats { + self.gather_one(target, v + (i as i64) * delta)?; } - let (bitwidths, remaining) = values.split_at(num_mini_blocks); - consumed_bytes += num_mini_blocks; - values = remaining; + Ok(()) + } + /// Gather a `slice` of elements into `target`. + fn gather_slice(&mut self, target: &mut Self::Target, slice: &[i64]) -> ParquetResult<()> { + for &v in slice { + self.gather_one(target, v)?; + } + Ok(()) + } + /// Gather a `chunk` of elements into `target`. 
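For context, a minimal gatherer that simply collects the decoded values into a `Vec<i64>` (essentially what the fuzz test added further below does) only needs the three mandatory methods; the bulk methods fall back to the trait defaults shown above:

```rust
use crate::parquet::error::ParquetResult;

struct VecGatherer;

impl DeltaGatherer for VecGatherer {
    type Target = Vec<i64>;

    fn target_len(&self, target: &Self::Target) -> usize {
        target.len()
    }
    fn target_reserve(&self, target: &mut Self::Target, n: usize) {
        target.reserve(n);
    }
    fn gather_one(&mut self, target: &mut Self::Target, v: i64) -> ParquetResult<()> {
        target.push(v);
        Ok(())
    }
    // gather_constant / gather_slice / gather_chunk use the default
    // implementations, which loop over gather_one.
}
```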
+ fn gather_chunk(&mut self, target: &mut Self::Target, chunk: &[i64; 64]) -> ParquetResult<()> { + self.gather_slice(target, chunk) + } +} - let mut block = Block { - min_delta, - _num_mini_blocks: num_mini_blocks, - values_per_mini_block, - bitwidths: bitwidths.iter(), - remaining: length, - values, - current_index: 0, - current_miniblock: None, - consumed_bytes, - }; +impl DeltaGatherer for SkipGatherer { + type Target = usize; - // Set up first mini-block - block.advance_miniblock()?; + fn target_len(&self, target: &Self::Target) -> usize { + *target + } + fn target_reserve(&self, _target: &mut Self::Target, _n: usize) {} - Ok(block) + fn gather_one(&mut self, target: &mut Self::Target, _v: i64) -> ParquetResult<()> { + *target += 1; + Ok(()) } + fn gather_constant( + &mut self, + target: &mut Self::Target, + _v: i64, + _delta: i64, + num_repeats: usize, + ) -> ParquetResult<()> { + *target += num_repeats; + Ok(()) + } + fn gather_chunk(&mut self, target: &mut Self::Target, chunk: &[i64; 64]) -> ParquetResult<()> { + *target += chunk.len(); + Ok(()) + } + fn gather_slice(&mut self, target: &mut Self::Target, slice: &[i64]) -> ParquetResult<()> { + *target += slice.len(); + Ok(()) + } +} - fn advance_miniblock(&mut self) -> ParquetResult<()> { - // unwrap is ok: we sliced it by num_mini_blocks in try_new - let num_bits = self.bitwidths.next().copied().unwrap() as usize; +impl DeltaGatherer for SumGatherer { + type Target = usize; - self.current_miniblock = if num_bits > 0 { - let length = std::cmp::min(self.remaining, self.values_per_mini_block); + fn target_len(&self, _target: &Self::Target) -> usize { + self.0 + } + fn target_reserve(&self, _target: &mut Self::Target, _n: usize) {} - let miniblock_length = ceil8(self.values_per_mini_block * num_bits); - if miniblock_length > self.values.len() { - return Err(ParquetError::oos( - "block must contain at least miniblock_length bytes (the mini block)", - )); - } - let (miniblock, remainder) = self.values.split_at(miniblock_length); - - self.values = remainder; - self.consumed_bytes += miniblock_length; - - Some( - bitpacked::Decoder::try_new(miniblock, num_bits, length) - .unwrap() - .collect_into_iter(), - ) - } else { - None - }; - self.current_index = 0; + fn gather_one(&mut self, target: &mut Self::Target, v: i64) -> ParquetResult<()> { + if v < 0 { + return Err(ParquetError::oos(format!( + "Invalid delta encoding length {v}" + ))); + } + *target += v as usize; + self.0 += 1; Ok(()) } -} + fn gather_constant( + &mut self, + target: &mut Self::Target, + v: i64, + delta: i64, + num_repeats: usize, + ) -> ParquetResult<()> { + if v < 0 || (delta < 0 && num_repeats > 0 && (num_repeats - 1) as i64 * delta + v < 0) { + return Err(ParquetError::oos("Invalid delta encoding length")); + } -impl<'a> Iterator for Block<'a> { - type Item = Result; + *target += lin_natural_sum(v, delta, num_repeats) as usize; - fn next(&mut self) -> Option { - if self.remaining == 0 { - return None; + Ok(()) + } + fn gather_slice(&mut self, target: &mut Self::Target, slice: &[i64]) -> ParquetResult<()> { + let min = slice.iter().copied().min().unwrap_or_default(); + if min < 0 { + return Err(ParquetError::oos(format!( + "Invalid delta encoding length {min}" + ))); } - let result = self.min_delta - + self - .current_miniblock - .as_mut() - .map(|x| x.next().unwrap_or_default()) - .unwrap_or(0) as i64; - self.current_index += 1; - self.remaining -= 1; - - if self.remaining > 0 && self.current_index == self.values_per_mini_block { - if let Err(e) = 
self.advance_miniblock() { - return Some(Err(e)); - } + + *target += slice.iter().copied().map(|v| v as usize).sum::(); + self.0 += slice.len(); + Ok(()) + } + fn gather_chunk(&mut self, target: &mut Self::Target, chunk: &[i64; 64]) -> ParquetResult<()> { + let min = chunk.iter().copied().min().unwrap_or_default(); + if min < 0 { + return Err(ParquetError::oos(format!( + "Invalid delta encoding length {min}" + ))); } + *target += chunk.iter().copied().map(|v| v as usize).sum::(); + self.0 += chunk.len(); + Ok(()) + } +} - Some(Ok(result)) +/// Gather the rest of the [`bitpacked::Decoder`] into `target` +fn gather_bitpacked( + target: &mut G::Target, + min_delta: i64, + last_value: &mut i64, + mut decoder: bitpacked::Decoder, + gatherer: &mut G, +) -> ParquetResult<()> { + let mut chunked = decoder.chunked(); + for mut chunk in &mut chunked { + for value in &mut chunk { + *last_value = last_value + .wrapping_add(*value as i64) + .wrapping_add(min_delta); + *value = *last_value as u64; + } + + let chunk = bytemuck::cast_ref(&chunk); + gatherer.gather_chunk(target, chunk)?; + } + + if let Some((mut chunk, length)) = chunked.next_inexact() { + let slice = &mut chunk[..length]; + + for value in slice.iter_mut() { + *last_value = last_value + .wrapping_add(*value as i64) + .wrapping_add(min_delta); + *value = *last_value as u64; + } + + let slice = bytemuck::cast_slice(slice); + gatherer.gather_slice(target, slice)?; } + + Ok(()) } -/// Decoder of parquets' `DELTA_BINARY_PACKED`. Implements `Iterator`. -/// # Implementation -/// This struct does not allocate on the heap. -#[derive(Debug)] -pub struct Decoder<'a> { - num_mini_blocks: usize, - values_per_mini_block: usize, - values_remaining: usize, - next_value: i64, - values: &'a [u8], - current_block: Option>, - // the total number of bytes consumed up to a given point, excluding the bytes on the current_block - consumed_bytes: usize, +/// Gather an entire [`MiniBlock`] into `target` +fn gather_miniblock( + target: &mut G::Target, + min_delta: i64, + bitwidth: u8, + values: &[u8], + values_per_miniblock: usize, + last_value: &mut i64, + gatherer: &mut G, +) -> ParquetResult<()> { + let bitwidth = bitwidth as usize; + + if bitwidth == 0 { + let v = last_value.wrapping_add(min_delta); + gatherer.gather_constant(target, v, min_delta, values_per_miniblock)?; + *last_value = last_value.wrapping_add(min_delta * values_per_miniblock as i64); + return Ok(()); + } + + debug_assert!(bitwidth <= 64); + debug_assert_eq!((bitwidth * values_per_miniblock).div_ceil(8), values.len()); + + let start_length = gatherer.target_len(target); + gather_bitpacked( + target, + min_delta, + last_value, + bitpacked::Decoder::new(values, bitwidth, values_per_miniblock), + gatherer, + )?; + let target_length = gatherer.target_len(target); + + debug_assert_eq!(target_length - start_length, values_per_miniblock); + + Ok(()) +} + +/// Gather an entire [`Block`] into `target` +fn gather_block<'a, G: DeltaGatherer>( + target: &mut G::Target, + num_miniblocks: usize, + values_per_miniblock: usize, + mut values: &'a [u8], + last_value: &mut i64, + gatherer: &mut G, +) -> ParquetResult<&'a [u8]> { + let (min_delta, consumed) = zigzag_leb128::decode(values); + values = &values[consumed..]; + let bitwidths; + (bitwidths, values) = values + .split_at_checked(num_miniblocks) + .ok_or_else(|| ParquetError::oos("Not enough bitwidths available in delta encoding"))?; + + gatherer.target_reserve(target, num_miniblocks * values_per_miniblock); + for &bitwidth in bitwidths { + let miniblock; 
+ (miniblock, values) = values + .split_at_checked((bitwidth as usize * values_per_miniblock).div_ceil(8)) + .ok_or_else(|| ParquetError::oos("Not enough bytes for miniblock in delta encoding"))?; + gather_miniblock( + target, + min_delta, + bitwidth, + miniblock, + values_per_miniblock, + last_value, + gatherer, + )?; + } + + Ok(values) } impl<'a> Decoder<'a> { - pub fn try_new(mut values: &'a [u8]) -> Result { - let mut consumed_bytes = 0; - let (block_size, consumed) = uleb128::decode(values); - consumed_bytes += consumed; - assert_eq!(block_size % 128, 0); - values = &values[consumed..]; - let (num_mini_blocks, consumed) = uleb128::decode(values); - let num_mini_blocks = num_mini_blocks as usize; - consumed_bytes += consumed; - values = &values[consumed..]; + pub fn try_new(mut values: &'a [u8]) -> ParquetResult<(Self, &'a [u8])> { + let header_err = || ParquetError::oos("Insufficient bytes for Delta encoding header"); + + // header: + // + + let (values_per_block, consumed) = uleb128::decode(values); + let values_per_block = values_per_block as usize; + values = values.get(consumed..).ok_or_else(header_err)?; + + assert_eq!(values_per_block % 128, 0); + + let (num_miniblocks_per_block, consumed) = uleb128::decode(values); + let num_miniblocks_per_block = num_miniblocks_per_block as usize; + values = values.get(consumed..).ok_or_else(header_err)?; + let (total_count, consumed) = uleb128::decode(values); let total_count = total_count as usize; - consumed_bytes += consumed; - values = &values[consumed..]; + values = values.get(consumed..).ok_or_else(header_err)?; + let (first_value, consumed) = zigzag_leb128::decode(values); - consumed_bytes += consumed; - values = &values[consumed..]; - - let values_per_mini_block = block_size as usize / num_mini_blocks; - assert_eq!(values_per_mini_block % 8, 0); - - // If we only have one value (first_value), there are no blocks. - let current_block = if total_count > 1 { - Some(Block::try_new( - values, - num_mini_blocks, - values_per_mini_block, - total_count - 1, - )?) - } else { - None - }; + values = values.get(consumed..).ok_or_else(header_err)?; + + assert_eq!(values_per_block % num_miniblocks_per_block, 0); + assert_eq!((values_per_block / num_miniblocks_per_block) % 32, 0); + + let values_per_miniblock = values_per_block / num_miniblocks_per_block; + assert_eq!(values_per_miniblock % 8, 0); - Ok(Self { - num_mini_blocks, - values_per_mini_block, - values_remaining: total_count, - next_value: first_value, + // We skip over all the values to determine where the slice stops. + // + // This also has the added benefit of error checking in advance, thus we can unwrap in + // other places. + + let mut rem = values; + if total_count > 1 { + let mut num_values_left = total_count - 1; + while num_values_left > 0 { + // If the number of values is does not need all the miniblocks anymore, we need to + // ignore the later miniblocks and regard them as having bitwidth = 0. + // + // Quoted from the specification: + // + // > If, in the last block, less than miniblocks + // > are needed to store the values, the bytes storing the bit widths of the + // > unneeded miniblocks are still present, their value should be zero, but readers + // > must accept arbitrary values as well. There are no additional padding bytes for + // > the miniblock bodies though, as if their bit widths were 0 (regardless of the + // > actual byte values). The reader knows when to stop reading by keeping track of + // > the number of values read. 
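A concrete instance of the rule quoted from the spec above (each block is laid out as `<min delta> <bitwidths of miniblocks> <miniblocks>`), with illustrative numbers:

```rust
// values_per_block = 256 and num_miniblocks_per_block = 4 give 64 values per
// miniblock. With only 100 values left, just ceil(100 / 64) = 2 miniblocks
// carry data; the other two bitwidth bytes are still present in the stream,
// but no miniblock bodies follow them.
let (values_per_miniblock, num_values_left) = (64usize, 100usize);
assert_eq!(num_values_left.div_ceil(values_per_miniblock), 2);
```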
+ let num_remaining_mini_blocks = usize::min( + num_miniblocks_per_block, + num_values_left.div_ceil(values_per_miniblock), + ); + + // block: + // + + let (_, consumed) = zigzag_leb128::decode(rem); + rem = rem.get(consumed..).ok_or_else(|| { + ParquetError::oos("No min-delta value in delta encoding miniblock") + })?; + + if rem.len() < num_miniblocks_per_block { + return Err(ParquetError::oos( + "Not enough bitwidths available in delta encoding", + )); + } + if let Some(err_bitwidth) = rem + .get(..num_remaining_mini_blocks) + .expect("num_remaining_mini_blocks <= num_miniblocks_per_block") + .iter() + .copied() + .find(|&bitwidth| bitwidth > MAX_BITWIDTH) + { + return Err(ParquetError::oos(format!( + "Delta encoding miniblock with bitwidth {err_bitwidth} higher than maximum {MAX_BITWIDTH} bits", + ))); + } + + let num_bitpacking_bytes = rem[..num_remaining_mini_blocks] + .iter() + .copied() + .map(|bitwidth| (bitwidth as usize * values_per_miniblock).div_ceil(8)) + .sum::(); + + rem = rem + .get(num_miniblocks_per_block + num_bitpacking_bytes..) + .ok_or_else(|| { + ParquetError::oos( + "Not enough bytes for all bitpacked values in delta encoding", + ) + })?; + + num_values_left = num_values_left.saturating_sub(values_per_block); + } + } + + let values = &values[..values.len() - rem.len()]; + + let decoder = Self { + num_miniblocks_per_block, + values_per_block, + values_remaining: total_count.saturating_sub(1), + last_value: first_value, values, - current_block, - consumed_bytes, - }) + + block: Block { + // @NOTE: + // We add one delta=0 into the buffered block which allows us to + // prepend like the `first_value` is just any normal value. + // + // This is a bit of a hack, but makes the rest of the logic + // **A LOT** simpler. + values_remaining: usize::from(total_count > 0), + min_delta: 0, + bitwidths: &[], + miniblock: MiniBlock { + decoder: bitpacked::Decoder::try_new_allow_zero(&[], 0, 1)?, + buffered: ::Unpacked::zero(), + unpacked_start: 0, + unpacked_end: 0, + }, + }, + }; + + Ok((decoder, rem)) } - /// Returns the total number of bytes consumed up to this point by [`Decoder`]. - pub fn consumed_bytes(&self) -> usize { - self.consumed_bytes + self.current_block.as_ref().map_or(0, |b| b.consumed_bytes) + /// Consume a new [`Block`] from `self.values`. + fn consume_block(&mut self) { + // @NOTE: All the panics here should be prevented in the `Decoder::try_new`. 
+ + debug_assert!(!self.values.is_empty()); + + let values_per_miniblock = self.values_per_miniblock(); + + let length = usize::min(self.values_remaining, self.values_per_block); + let actual_num_miniblocks = usize::min( + self.num_miniblocks_per_block, + length.div_ceil(values_per_miniblock), + ); + + debug_assert!(actual_num_miniblocks > 0); + + // + + let (min_delta, consumed) = zigzag_leb128::decode(self.values); + + self.values = &self.values[consumed..]; + let (bitwidths, remainder) = self.values.split_at(self.num_miniblocks_per_block); + + let first_bitwidth = bitwidths[0]; + let bitwidths = &bitwidths[1..actual_num_miniblocks]; + debug_assert!(first_bitwidth <= MAX_BITWIDTH); + let first_bitwidth = first_bitwidth as usize; + + let values_in_first_miniblock = usize::min(length, values_per_miniblock); + let num_allocated_bytes = (first_bitwidth * values_per_miniblock).div_ceil(8); + let num_actual_bytes = (first_bitwidth * values_in_first_miniblock).div_ceil(8); + let (bytes, remainder) = remainder.split_at(num_allocated_bytes); + let bytes = &bytes[..num_actual_bytes]; + + let decoder = + bitpacked::Decoder::new_allow_zero(bytes, first_bitwidth, values_in_first_miniblock); + + self.block = Block { + min_delta, + bitwidths, + values_remaining: length, + miniblock: MiniBlock { + decoder, + // We can leave this as it should not be read before it is updated + buffered: self.block.miniblock.buffered, + unpacked_start: 0, + unpacked_end: 0, + }, + }; + + self.values_remaining -= length; + self.values = remainder; } - fn load_delta(&mut self) -> Result { - // At this point we must have at least one block and value available - let current_block = self.current_block.as_mut().unwrap(); - if let Some(x) = current_block.next() { - x - } else { - // load next block - self.values = &self.values[current_block.consumed_bytes..]; - self.consumed_bytes += current_block.consumed_bytes; + /// Gather `n` elements from the current [`MiniBlock`] to `target` + fn gather_miniblock_n_into( + &mut self, + target: &mut G::Target, + mut n: usize, + gatherer: &mut G, + ) -> ParquetResult<()> { + debug_assert!(n > 0); + debug_assert!(self.miniblock_len() >= n); + + // If the `num_bits == 0`, the delta is constant and equal to `min_delta`. The + // `bitpacked::Decoder` basically only keeps track of the length. 
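Concretely, for the `num_bits == 0` fast path described in the comment above, the decoded run is an arithmetic sequence driven purely by `min_delta` (illustrative numbers):

```rust
// Starting from last_value = 7 with min_delta = 3 and 4 repeats, the gathered
// values are 10, 13, 16, 19, and last_value advances to 19.
let (last_value, min_delta, num_repeats) = (7i64, 3i64, 4i64);
let run: Vec<i64> = (1..=num_repeats).map(|i| last_value + i * min_delta).collect();
assert_eq!(run, vec![10, 13, 16, 19]);
```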
+ if self.block.miniblock.decoder.num_bits() == 0 { + let num_repeats = usize::min(self.miniblock_len(), n); + let v = self.last_value.wrapping_add(self.block.min_delta); + gatherer.gather_constant(target, v, self.block.min_delta, num_repeats)?; + self.last_value = self + .last_value + .wrapping_add(self.block.min_delta * num_repeats as i64); + self.block.miniblock.decoder.length -= num_repeats; + return Ok(()); + } - let next_block = Block::try_new( - self.values, - self.num_mini_blocks, - self.values_per_mini_block, - self.values_remaining, + if self.block.miniblock.unpacked_start < self.block.miniblock.unpacked_end { + let length = usize::min( + n, + self.block.miniblock.unpacked_end - self.block.miniblock.unpacked_start, ); - match next_block { - Ok(mut next_block) => { - let delta = next_block - .next() - .ok_or_else(|| ParquetError::oos("Missing block"))?; - self.current_block = Some(next_block); - delta - }, - Err(e) => Err(e), + self.block.miniblock.buffered + [self.block.miniblock.unpacked_start..self.block.miniblock.unpacked_start + length] + .iter_mut() + .for_each(|v| { + self.last_value = self + .last_value + .wrapping_add(*v as i64) + .wrapping_add(self.block.min_delta); + *v = self.last_value as u64; + }); + gatherer.gather_slice( + target, + bytemuck::cast_slice( + &self.block.miniblock.buffered[self.block.miniblock.unpacked_start + ..self.block.miniblock.unpacked_start + length], + ), + )?; + n -= length; + self.block.miniblock.unpacked_start += length; + } + + if n == 0 { + return Ok(()); + } + + const ITEMS_PER_PACK: usize = <::Unpacked as Unpacked>::LENGTH; + for _ in 0..n / ITEMS_PER_PACK { + let mut chunk = self.block.miniblock.decoder.chunked().next().unwrap(); + chunk.iter_mut().for_each(|v| { + self.last_value = self + .last_value + .wrapping_add(*v as i64) + .wrapping_add(self.block.min_delta); + *v = self.last_value as u64; + }); + gatherer.gather_chunk(target, bytemuck::cast_ref(&chunk))?; + n -= ITEMS_PER_PACK; + } + + if n == 0 { + return Ok(()); + } + + let Some((chunk, len)) = self.block.miniblock.decoder.chunked().next_inexact() else { + debug_assert_eq!(n, 0); + self.block.miniblock.buffered = ::Unpacked::zero(); + self.block.miniblock.unpacked_start = 0; + self.block.miniblock.unpacked_end = 0; + return Ok(()); + }; + + self.block.miniblock.buffered = chunk; + self.block.miniblock.unpacked_start = 0; + self.block.miniblock.unpacked_end = len; + + if n > 0 { + let length = usize::min(n, self.block.miniblock.unpacked_end); + self.block.miniblock.buffered[..length] + .iter_mut() + .for_each(|v| { + self.last_value = self + .last_value + .wrapping_add(*v as i64) + .wrapping_add(self.block.min_delta); + *v = self.last_value as u64; + }); + gatherer.gather_slice( + target, + bytemuck::cast_slice(&self.block.miniblock.buffered[..length]), + )?; + self.block.miniblock.unpacked_start = length; + } + + Ok(()) + } + + /// Gather `n` elements from the current [`Block`] to `target` + fn gather_block_n_into( + &mut self, + target: &mut G::Target, + n: usize, + gatherer: &mut G, + ) -> ParquetResult<()> { + let values_per_miniblock = self.values_per_miniblock(); + + debug_assert!(n <= self.values_per_block); + debug_assert!(self.values_per_block >= values_per_miniblock); + debug_assert_eq!(self.values_per_block % values_per_miniblock, 0); + + let mut n = usize::min(self.block.values_remaining, n); + + if n == 0 { + return Ok(()); + } + + let miniblock_len = self.miniblock_len(); + if n < miniblock_len { + self.gather_miniblock_n_into(target, n, gatherer)?; + 
debug_assert_eq!(self.miniblock_len(), miniblock_len - n); + self.block.values_remaining -= n; + return Ok(()); + } + + if miniblock_len > 0 { + self.gather_miniblock_n_into(target, miniblock_len, gatherer)?; + n -= miniblock_len; + self.block.values_remaining -= miniblock_len; + } + + while n >= values_per_miniblock { + let bitwidth = self.block.bitwidths[0]; + self.block.bitwidths = &self.block.bitwidths[1..]; + + let miniblock; + (miniblock, self.values) = self + .values + .split_at((bitwidth as usize * values_per_miniblock).div_ceil(8)); + gather_miniblock( + target, + self.block.min_delta, + bitwidth, + miniblock, + values_per_miniblock, + &mut self.last_value, + gatherer, + )?; + n -= values_per_miniblock; + self.block.values_remaining -= values_per_miniblock; + } + + if n == 0 { + return Ok(()); + } + + if !self.block.bitwidths.is_empty() { + let bitwidth = self.block.bitwidths[0]; + self.block.bitwidths = &self.block.bitwidths[1..]; + + if bitwidth > MAX_BITWIDTH { + return Err(ParquetError::oos(format!( + "Delta encoding bitwidth '{bitwidth}' is larger than maximum {MAX_BITWIDTH})" + ))); + } + + let length = usize::min(values_per_miniblock, self.block.values_remaining); + + let num_allocated_bytes = (bitwidth as usize * values_per_miniblock).div_ceil(8); + let num_actual_bytes = (bitwidth as usize * length).div_ceil(8); + + let miniblock; + (miniblock, self.values) = + self.values + .split_at_checked(num_allocated_bytes) + .ok_or(ParquetError::oos( + "Not enough space for delta encoded miniblock", + ))?; + + let miniblock = &miniblock[..num_actual_bytes]; + + let decoder = + bitpacked::Decoder::try_new_allow_zero(miniblock, bitwidth as usize, length)?; + self.block.miniblock = MiniBlock { + decoder, + buffered: self.block.miniblock.buffered, + unpacked_start: 0, + unpacked_end: 0, + }; + + if n > 0 { + self.gather_miniblock_n_into(target, n, gatherer)?; + self.block.values_remaining -= n; } } + + Ok(()) } -} -impl<'a> Iterator for Decoder<'a> { - type Item = Result; + /// Gather `n` elements to `target` + pub fn gather_n_into( + &mut self, + target: &mut G::Target, + mut n: usize, + gatherer: &mut G, + ) -> ParquetResult<()> { + n = usize::min(n, self.len()); + + if n == 0 { + return Ok(()); + } + + let values_per_miniblock = self.values_per_block / self.num_miniblocks_per_block; + + let start_num_values_remaining = self.block.values_remaining; + if n <= self.block.values_remaining { + self.gather_block_n_into(target, n, gatherer)?; + debug_assert_eq!(self.block.values_remaining, start_num_values_remaining - n); + return Ok(()); + } + + n -= self.block.values_remaining; + self.gather_block_n_into(target, self.block.values_remaining, gatherer)?; + debug_assert_eq!(self.block.values_remaining, 0); + + while usize::min(n, self.values_remaining) >= self.values_per_block { + self.values = gather_block( + target, + self.num_miniblocks_per_block, + values_per_miniblock, + self.values, + &mut self.last_value, + gatherer, + )?; + n -= self.values_per_block; + self.values_remaining -= self.values_per_block; + } - fn next(&mut self) -> Option { - if self.values_remaining == 0 { - return None; + if n == 0 { + return Ok(()); } - let result = Some(Ok(self.next_value)); + self.consume_block(); + self.gather_block_n_into(target, n, gatherer)?; - self.values_remaining -= 1; - if self.values_remaining == 0 { - // do not try to load another block - return result; + Ok(()) + } + + pub fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { + let mut gatherer = SkipGatherer; + 
self.gather_n_into(&mut 0usize, n, &mut gatherer) + } + + #[cfg(test)] + pub(crate) fn collect_n>( + &mut self, + e: &mut E, + n: usize, + ) -> ParquetResult<()> { + struct ExtendGatherer<'a, E: std::fmt::Debug + Extend>( + std::marker::PhantomData<&'a E>, + ); + + impl<'a, E: std::fmt::Debug + Extend> DeltaGatherer for ExtendGatherer<'a, E> { + type Target = (usize, &'a mut E); + + fn target_len(&self, target: &Self::Target) -> usize { + target.0 + } + + fn target_reserve(&self, _target: &mut Self::Target, _n: usize) {} + + fn gather_one(&mut self, target: &mut Self::Target, v: i64) -> ParquetResult<()> { + target.1.extend(Some(v)); + target.0 += 1; + Ok(()) + } } - let delta = match self.load_delta() { - Ok(delta) => delta, - Err(e) => return Some(Err(e)), - }; + let mut gatherer = ExtendGatherer(std::marker::PhantomData); + let mut target = (0, e); + + self.gather_n_into(&mut target, n, &mut gatherer) + } + + #[cfg(test)] + pub(crate) fn collect + Default>( + mut self, + ) -> ParquetResult { + let mut e = E::default(); + self.collect_n(&mut e, self.len())?; + Ok(e) + } + + pub fn len(&self) -> usize { + self.values_remaining + self.block.values_remaining + } - self.next_value += delta; - result + fn values_per_miniblock(&self) -> usize { + debug_assert_eq!(self.values_per_block % self.num_miniblocks_per_block, 0); + self.values_per_block / self.num_miniblocks_per_block } - fn size_hint(&self) -> (usize, Option) { - (self.values_remaining, Some(self.values_remaining)) + fn miniblock_len(&self) -> usize { + self.block.miniblock.unpacked_end - self.block.miniblock.unpacked_start + + self.block.miniblock.decoder.len() } } @@ -259,11 +825,11 @@ mod tests { // first_value: 2 <=z> 1 let data = &[128, 1, 4, 1, 2]; - let mut decoder = Decoder::try_new(data).unwrap(); - let r = decoder.by_ref().collect::, _>>().unwrap(); + let (decoder, rem) = Decoder::try_new(data).unwrap(); + let r = decoder.collect::>().unwrap(); assert_eq!(&r[..], &[1]); - assert_eq!(decoder.consumed_bytes(), 5); + assert_eq!(data.len() - rem.len(), 5); } #[test] @@ -280,12 +846,12 @@ mod tests { // bit_width: 0 let data = &[128, 1, 4, 5, 2, 2, 0, 0, 0, 0]; - let mut decoder = Decoder::try_new(data).unwrap(); - let r = decoder.by_ref().collect::, _>>().unwrap(); + let (decoder, rem) = Decoder::try_new(data).unwrap(); + let r = decoder.collect::>().unwrap(); assert_eq!(expected, r); - assert_eq!(decoder.consumed_bytes(), 10); + assert_eq!(data.len() - rem.len(), 10); } #[test] @@ -311,11 +877,11 @@ mod tests { 1, 2, 3, ]; - let mut decoder = Decoder::try_new(data).unwrap(); - let r = decoder.by_ref().collect::, _>>().unwrap(); + let (decoder, rem) = Decoder::try_new(data).unwrap(); + let r = decoder.collect::>().unwrap(); assert_eq!(expected, r); - assert_eq!(decoder.consumed_bytes(), data.len() - 3); + assert_eq!(rem, &[1, 2, 3]); } #[test] @@ -357,10 +923,11 @@ mod tests { -2, 2, 6, 10, 14, 18, 22, 26, 30, 34, 38, 42, 46, 50, ]; - let mut decoder = Decoder::try_new(data).unwrap(); - let r = decoder.by_ref().collect::, _>>().unwrap(); + let (decoder, rem) = Decoder::try_new(data).unwrap(); + let r = decoder.collect::>().unwrap(); assert_eq!(&expected[..], &r[..]); - assert_eq!(decoder.consumed_bytes(), data.len() - 3); + assert_eq!(data.len() - rem.len(), data.len() - 3); + assert_eq!(rem.len(), 3); } } diff --git a/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/encoder.rs b/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/encoder.rs index 9bdb861504d1..24b6ea6523b8 100644 --- 
a/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/encoder.rs +++ b/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/encoder.rs @@ -5,49 +5,60 @@ use crate::parquet::encoding::ceil8; /// # Implementation /// * This function does not allocate on the heap. /// * The number of mini-blocks is always 1. This may change in the future. -pub fn encode>(mut iterator: I, buffer: &mut Vec) { - let block_size = 128; - let mini_blocks = 1; +pub fn encode>( + mut iterator: I, + buffer: &mut Vec, + num_miniblocks_per_block: usize, +) { + const BLOCK_SIZE: usize = 256; + assert!([1, 2, 4].contains(&num_miniblocks_per_block)); + let values_per_miniblock = BLOCK_SIZE / num_miniblocks_per_block; let mut container = [0u8; 10]; - let encoded_len = uleb128::encode(block_size, &mut container); + let encoded_len = uleb128::encode(BLOCK_SIZE as u64, &mut container); buffer.extend_from_slice(&container[..encoded_len]); - let encoded_len = uleb128::encode(mini_blocks, &mut container); + let encoded_len = uleb128::encode(num_miniblocks_per_block as u64, &mut container); buffer.extend_from_slice(&container[..encoded_len]); - let length = iterator.size_hint().1.unwrap(); + let length = iterator.len(); let encoded_len = uleb128::encode(length as u64, &mut container); buffer.extend_from_slice(&container[..encoded_len]); - let mut values = [0i64; 128]; - let mut deltas = [0u64; 128]; + let mut values = [0i64; BLOCK_SIZE]; + let mut deltas = [0u64; BLOCK_SIZE]; + let mut num_bits = [0u8; 4]; let first_value = iterator.next().unwrap_or_default(); let (container, encoded_len) = zigzag_leb128::encode(first_value); buffer.extend_from_slice(&container[..encoded_len]); let mut prev = first_value; - let mut length = iterator.size_hint().1.unwrap(); + let mut length = iterator.len(); while length != 0 { let mut min_delta = i64::MAX; let mut max_delta = i64::MIN; - let mut num_bits = 0; - for (i, integer) in (0..128).zip(&mut iterator) { - let delta = integer - prev; + for (i, integer) in iterator.by_ref().enumerate().take(BLOCK_SIZE) { + if i % values_per_miniblock == 0 { + min_delta = i64::MAX; + max_delta = i64::MIN + } + + let delta = integer.wrapping_sub(prev); min_delta = min_delta.min(delta); max_delta = max_delta.max(delta); - num_bits = 64 - (max_delta - min_delta).leading_zeros(); + let miniblock_idx = i / values_per_miniblock; + num_bits[miniblock_idx] = (64 - max_delta.abs_diff(min_delta).leading_zeros()) as u8; values[i] = delta; prev = integer; } - let consumed = std::cmp::min(length - iterator.size_hint().1.unwrap(), 128); - length = iterator.size_hint().1.unwrap(); + let consumed = std::cmp::min(length - iterator.len(), BLOCK_SIZE); + length = iterator.len(); let values = &values[..consumed]; values.iter().zip(deltas.iter_mut()).for_each(|(v, delta)| { - *delta = (v - min_delta) as u64; + *delta = v.wrapping_sub(min_delta) as u64; }); // @@ -55,19 +66,32 @@ pub fn encode>(mut iterator: I, buffer: &mut Vec) { buffer.extend_from_slice(&container[..encoded_len]); // one miniblock => 1 byte - buffer.push(num_bits as u8); - write_miniblock(buffer, num_bits as usize, deltas); + let mut values_remaining = consumed; + buffer.extend_from_slice(&num_bits[..num_miniblocks_per_block]); + for i in 0..num_miniblocks_per_block { + if values_remaining == 0 { + break; + } + + values_remaining = values_remaining.saturating_sub(values_per_miniblock); + write_miniblock( + buffer, + num_bits[i], + &deltas[i * values_per_miniblock..(i + 1) * values_per_miniblock], + ); + } } } -fn write_miniblock(buffer: &mut Vec, 
num_bits: usize, deltas: [u64; 128]) { +fn write_miniblock(buffer: &mut Vec, num_bits: u8, deltas: &[u64]) { + let num_bits = num_bits as usize; if num_bits > 0 { let start = buffer.len(); // bitpack encode all (deltas.len = 128 which is a multiple of 32) let bytes_needed = start + ceil8(deltas.len() * num_bits); buffer.resize(bytes_needed, 0); - bitpacked::encode(deltas.as_ref(), num_bits, &mut buffer[start..]); + bitpacked::encode(deltas, num_bits, &mut buffer[start..]); let bytes_needed = start + ceil8(deltas.len() * num_bits); buffer.truncate(bytes_needed); @@ -80,8 +104,8 @@ mod tests { #[test] fn constant_delta() { - // header: [128, 1, 1, 5, 2]: - // block size: 128 <=u> 128, 1 + // header: [128, 2, 1, 5, 2]: + // block size: 256 <=u> 128, 2 // mini-blocks: 1 <=u> 1 // elements: 5 <=u> 5 // first_value: 2 <=z> 1 @@ -89,10 +113,10 @@ mod tests { // min_delta: 1 <=z> 2 // bitwidth: 0 let data = 1..=5; - let expected = vec![128u8, 1, 1, 5, 2, 2, 0]; + let expected = vec![128u8, 2, 1, 5, 2, 2, 0]; let mut buffer = vec![]; - encode(data, &mut buffer); + encode(data.collect::>().into_iter(), &mut buffer, 1); assert_eq!(expected, buffer); } @@ -100,8 +124,8 @@ mod tests { fn negative_min_delta() { // max - min = 1 - -4 = 5 let data = vec![1, 2, 3, 4, 5, 1]; - // header: [128, 1, 4, 6, 2] - // block size: 128 <=u> 128, 1 + // header: [128, 2, 4, 6, 2] + // block size: 256 <=u> 128, 2 // mini-blocks: 1 <=u> 1 // elements: 6 <=u> 5 // first_value: 2 <=z> 1 @@ -112,11 +136,11 @@ mod tests { // 0b01101101 // 0b00001011 // ] - let mut expected = vec![128u8, 1, 1, 6, 2, 7, 3, 0b01101101, 0b00001011]; - expected.extend(std::iter::repeat(0).take(128 * 3 / 8 - 2)); // 128 values, 3 bits, 2 already used + let mut expected = vec![128u8, 2, 1, 6, 2, 7, 3, 0b01101101, 0b00001011]; + expected.extend(std::iter::repeat(0).take(256 * 3 / 8 - 2)); // 128 values, 3 bits, 2 already used let mut buffer = vec![]; - encode(data.into_iter(), &mut buffer); + encode(data.into_iter(), &mut buffer, 1); assert_eq!(expected, buffer); } } diff --git a/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/fuzz.rs b/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/fuzz.rs new file mode 100644 index 000000000000..dc16bc8353fd --- /dev/null +++ b/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/fuzz.rs @@ -0,0 +1,76 @@ +#[ignore = "Fuzz test. 
Takes too long"] +#[test] +fn fuzz_test_delta_encoding() -> Result<(), Box> { + use rand::Rng; + + use super::DeltaGatherer; + use crate::parquet::error::ParquetResult; + + struct SimpleGatherer; + + impl DeltaGatherer for SimpleGatherer { + type Target = Vec; + + fn target_len(&self, target: &Self::Target) -> usize { + target.len() + } + + fn target_reserve(&self, target: &mut Self::Target, n: usize) { + target.reserve(n); + } + + fn gather_one(&mut self, target: &mut Self::Target, v: i64) -> ParquetResult<()> { + target.push(v); + Ok(()) + } + } + + const MIN_VALUES: usize = 1; + const MAX_VALUES: usize = 515; + + const MIN: i64 = i64::MIN; + const MAX: i64 = i64::MAX; + + const NUM_ITERATIONS: usize = 1_000_000; + + let mut values = Vec::with_capacity(MAX_VALUES); + let mut rng = rand::thread_rng(); + + let mut encoded = Vec::with_capacity(MAX_VALUES); + let mut decoded = Vec::with_capacity(MAX_VALUES); + let mut gatherer = SimpleGatherer; + + for i in 0..NUM_ITERATIONS { + values.clear(); + + let num_values = rng.gen_range(MIN_VALUES..=MAX_VALUES); + values.extend(std::iter::from_fn(|| Some(rng.gen_range(MIN..=MAX))).take(num_values)); + + encoded.clear(); + decoded.clear(); + + super::encode( + values.iter().copied(), + &mut encoded, + 1 << rng.gen_range(0..=2), + ); + let (mut decoder, rem) = super::Decoder::try_new(&encoded)?; + + assert!(rem.is_empty()); + + let mut num_remaining = num_values; + while num_remaining > 0 { + let n = rng.gen_range(1usize..=num_remaining); + decoder.gather_n_into(&mut decoded, n, &mut gatherer)?; + num_remaining -= n; + } + + assert_eq!(values, decoded); + + if i % 1000 == 999 { + eprintln!("[INFO]: {} iterations done.", i + 1); + } + } + + Ok(()) +} diff --git a/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/mod.rs b/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/mod.rs index 4f7922821c5f..4a32610a302e 100644 --- a/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/mod.rs +++ b/crates/polars-parquet/src/parquet/encoding/delta_bitpacked/mod.rs @@ -1,23 +1,52 @@ mod decoder; mod encoder; - -pub use decoder::Decoder; -pub use encoder::encode; +mod fuzz; + +pub(crate) use decoder::{Decoder, DeltaGatherer, SumGatherer}; +pub(crate) use encoder::encode; + +/// The sum of `start, start + delta, start + 2 * delta, ... len times`. 
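The helper defined just below uses the closed form `start*len + delta*(len*(len-1)/2)`. A naive reference, ignoring overflow, is handy for sanity checks like the ones in the tests:

```rust
/// Naive reference implementation (illustration only, ignores overflow).
fn lin_natural_sum_ref(start: i64, delta: i64, len: usize) -> i64 {
    (0..len as i64).map(|i| start + i * delta).sum()
}

assert_eq!(lin_natural_sum_ref(0, 1, 4), 6);  // 0 + 1 + 2 + 3
assert_eq!(lin_natural_sum_ref(2, 2, 3), 12); // 2 + 4 + 6
```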
+pub(crate) fn lin_natural_sum(start: i64, delta: i64, len: usize) -> i64 { + debug_assert!(len < i64::MAX as usize); + + let base = start * len as i64; + let sum = if len == 0 { + 0 + } else { + let is_odd = len & 1; + // SUM_i=0^n f * i = f * (n(n+1)/2) + let sum = (len >> (is_odd ^ 1)) * (len.wrapping_sub(1) >> is_odd); + delta * sum as i64 + }; + + base + sum +} #[cfg(test)] mod tests { use super::*; - use crate::parquet::error::ParquetError; + use crate::parquet::error::{ParquetError, ParquetResult}; + + #[test] + fn linear_natural_sum() { + assert_eq!(lin_natural_sum(0, 0, 0), 0); + assert_eq!(lin_natural_sum(10, 4, 0), 0); + assert_eq!(lin_natural_sum(0, 1, 1), 0); + assert_eq!(lin_natural_sum(0, 1, 3), 3); + assert_eq!(lin_natural_sum(0, 1, 4), 6); + assert_eq!(lin_natural_sum(0, 2, 3), 6); + assert_eq!(lin_natural_sum(2, 2, 3), 12); + } #[test] fn basic() -> Result<(), ParquetError> { let data = vec![1, 3, 1, 2, 3]; let mut buffer = vec![]; - encode(data.clone().into_iter(), &mut buffer); - let iter = Decoder::try_new(&buffer)?; + encode(data.clone().into_iter(), &mut buffer, 1); + let (iter, _) = Decoder::try_new(&buffer)?; - let result = iter.collect::, _>>()?; + let result = iter.collect::>()?; assert_eq!(result, data); Ok(()) } @@ -27,10 +56,10 @@ mod tests { let data = vec![1, 3, -1, 2, 3]; let mut buffer = vec![]; - encode(data.clone().into_iter(), &mut buffer); - let iter = Decoder::try_new(&buffer)?; + encode(data.clone().into_iter(), &mut buffer, 1); + let (iter, _) = Decoder::try_new(&buffer)?; - let result = iter.collect::, _>>()?; + let result = iter.collect::>()?; assert_eq!(result, data); Ok(()) } @@ -48,10 +77,10 @@ mod tests { ]; let mut buffer = vec![]; - encode(data.clone().into_iter(), &mut buffer); - let iter = Decoder::try_new(&buffer)?; + encode(data.clone().into_iter(), &mut buffer, 1); + let (iter, _) = Decoder::try_new(&buffer)?; - let result = iter.collect::, ParquetError>>()?; + let result = iter.collect::>()?; assert_eq!(result, data); Ok(()) } @@ -64,10 +93,10 @@ mod tests { } let mut buffer = vec![]; - encode(data.clone().into_iter(), &mut buffer); - let iter = Decoder::try_new(&buffer)?; + encode(data.clone().into_iter(), &mut buffer, 1); + let (iter, _) = Decoder::try_new(&buffer)?; - let result = iter.collect::, _>>()?; + let result = iter.collect::>()?; assert_eq!(result, data); Ok(()) } @@ -77,14 +106,47 @@ mod tests { let data = vec![2, 3, 1, 2, 1]; let mut buffer = vec![]; - encode(data.clone().into_iter(), &mut buffer); - let len = buffer.len(); - let mut iter = Decoder::try_new(&buffer)?; + encode(data.clone().into_iter(), &mut buffer, 1); + let (iter, _) = Decoder::try_new(&buffer)?; + + let result = iter.collect::>()?; + assert_eq!(result, data); + + Ok(()) + } + + #[test] + fn overflow_constant() -> ParquetResult<()> { + let data = vec![i64::MIN, i64::MAX, i64::MIN, i64::MAX]; + + let mut buffer = vec![]; + encode(data.clone().into_iter(), &mut buffer, 1); + let (iter, _) = Decoder::try_new(&buffer)?; + + let result = iter.collect::>()?; + assert_eq!(result, data); + + Ok(()) + } + + #[test] + fn overflow_vary() -> ParquetResult<()> { + let data = vec![ + 0, + i64::MAX, + i64::MAX - 1, + i64::MIN + 1, + i64::MAX, + i64::MIN + 2, + ]; + + let mut buffer = vec![]; + encode(data.clone().into_iter(), &mut buffer, 1); + let (iter, _) = Decoder::try_new(&buffer)?; - let result = iter.by_ref().collect::, _>>()?; + let result = iter.collect::>()?; assert_eq!(result, data); - assert_eq!(iter.consumed_bytes(), len); Ok(()) } } diff --git 
a/crates/polars-parquet/src/parquet/encoding/delta_byte_array/decoder.rs b/crates/polars-parquet/src/parquet/encoding/delta_byte_array/decoder.rs index 9196eaedb7c8..03889e0aa5d3 100644 --- a/crates/polars-parquet/src/parquet/encoding/delta_byte_array/decoder.rs +++ b/crates/polars-parquet/src/parquet/encoding/delta_byte_array/decoder.rs @@ -1,5 +1,6 @@ -use super::super::{delta_bitpacked, delta_length_byte_array}; -use crate::parquet::error::ParquetError; +use super::super::delta_bitpacked; +use crate::parquet::encoding::delta_bitpacked::SumGatherer; +use crate::parquet::error::ParquetResult; /// Decodes according to [Delta strings](https://github.com/apache/parquet-format/blob/master/Encodings.md#delta-strings-delta_byte_array--7), /// prefixes, lengths and values @@ -7,32 +8,47 @@ use crate::parquet::error::ParquetError; /// This struct does not allocate on the heap. #[derive(Debug)] pub struct Decoder<'a> { - values: &'a [u8], - prefix_lengths: delta_bitpacked::Decoder<'a>, + pub(crate) prefix_lengths: delta_bitpacked::Decoder<'a>, + pub(crate) suffix_lengths: delta_bitpacked::Decoder<'a>, + pub(crate) values: &'a [u8], + + pub(crate) offset: usize, + pub(crate) last: Vec, } impl<'a> Decoder<'a> { - pub fn try_new(values: &'a [u8]) -> Result { - let prefix_lengths = delta_bitpacked::Decoder::try_new(values)?; + pub fn try_new(values: &'a [u8]) -> ParquetResult { + let (prefix_lengths, values) = delta_bitpacked::Decoder::try_new(values)?; + let (suffix_lengths, values) = delta_bitpacked::Decoder::try_new(values)?; + Ok(Self { - values, prefix_lengths, + suffix_lengths, + values, + + offset: 0, + last: Vec::with_capacity(32), }) } - pub fn into_lengths(self) -> Result, ParquetError> { - assert_eq!(self.prefix_lengths.size_hint().0, 0); - delta_length_byte_array::Decoder::try_new( - &self.values[self.prefix_lengths.consumed_bytes()..], - ) + pub fn values(&self) -> &'a [u8] { + self.values } -} -impl<'a> Iterator for Decoder<'a> { - type Item = Result; + pub fn len(&self) -> usize { + debug_assert_eq!(self.prefix_lengths.len(), self.suffix_lengths.len()); + self.prefix_lengths.len() + } - fn next(&mut self) -> Option { - self.prefix_lengths.next().map(|x| x.map(|x| x as u32)) + pub fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { + let mut prefix_sum = 0usize; + self.prefix_lengths + .gather_n_into(&mut prefix_sum, n, &mut SumGatherer(0))?; + let mut suffix_sum = 0usize; + self.suffix_lengths + .gather_n_into(&mut suffix_sum, n, &mut SumGatherer(0))?; + self.offset += prefix_sum + suffix_sum; + Ok(()) } } @@ -40,8 +56,44 @@ impl<'a> Iterator for Decoder<'a> { mod tests { use super::*; + impl<'a> Iterator for Decoder<'a> { + type Item = ParquetResult>; + + fn next(&mut self) -> Option { + if self.len() == 0 { + return None; + } + + let mut prefix_length = vec![]; + let mut suffix_length = vec![]; + if let Err(e) = self.prefix_lengths.collect_n(&mut prefix_length, 1) { + return Some(Err(e)); + } + if let Err(e) = self.suffix_lengths.collect_n(&mut suffix_length, 1) { + return Some(Err(e)); + } + let prefix_length = prefix_length[0]; + let suffix_length = suffix_length[0]; + + let prefix_length = prefix_length as usize; + let suffix_length = suffix_length as usize; + + let mut value = Vec::with_capacity(prefix_length + suffix_length); + + value.extend_from_slice(&self.last[..prefix_length]); + value.extend_from_slice(&self.values[self.offset..self.offset + suffix_length]); + + self.last.clear(); + self.last.extend_from_slice(&value); + + self.offset += suffix_length; + + 
Some(Ok(value)) + } + } + #[test] - fn test_bla() -> Result<(), ParquetError> { + fn test_bla() -> ParquetResult<()> { // VALIDATED from spark==3.1.1 let data = &[ 128, 1, 4, 2, 0, 0, 0, 0, 0, 0, 128, 1, 4, 2, 10, 0, 0, 0, 0, 0, 72, 101, 108, 108, @@ -50,31 +102,16 @@ mod tests { // because they are beyond the sum of all lengths. 1, 2, 3, ]; - // result of encoding - let expected = &["Hello", "World"]; - let expected_lengths = expected.iter().map(|x| x.len() as i32).collect::>(); - let expected_prefixes = vec![0, 0]; - let expected_values = expected.join(""); - let expected_values = expected_values.as_bytes(); - - let mut decoder = Decoder::try_new(data)?; - let prefixes = decoder.by_ref().collect::, _>>()?; - assert_eq!(prefixes, expected_prefixes); - - // move to the lengths - let mut decoder = decoder.into_lengths()?; - - let lengths = decoder.by_ref().collect::, _>>()?; - assert_eq!(lengths, expected_lengths); - - // move to the values - let values = decoder.values(); - assert_eq!(values, expected_values); + + let decoder = Decoder::try_new(data)?; + let values = decoder.collect::, _>>()?; + assert_eq!(values, vec![b"Hello".to_vec(), b"World".to_vec()]); + Ok(()) } #[test] - fn test_with_prefix() -> Result<(), ParquetError> { + fn test_with_prefix() -> ParquetResult<()> { // VALIDATED from spark==3.1.1 let data = &[ 128, 1, 4, 2, 0, 6, 0, 0, 0, 0, 128, 1, 4, 2, 10, 4, 0, 0, 0, 0, 72, 101, 108, 108, @@ -83,24 +120,11 @@ mod tests { // because they are beyond the sum of all lengths. 1, 2, 3, ]; - // result of encoding - let expected_lengths = vec![5, 7]; - let expected_prefixes = vec![0, 3]; - let expected_values = b"Helloicopter"; - - let mut decoder = Decoder::try_new(data)?; - let prefixes = decoder.by_ref().collect::, _>>()?; - assert_eq!(prefixes, expected_prefixes); - - // move to the lengths - let mut decoder = decoder.into_lengths()?; - let lengths = decoder.by_ref().collect::, _>>()?; - assert_eq!(lengths, expected_lengths); + let decoder = Decoder::try_new(data)?; + let prefixes = decoder.collect::, _>>()?; + assert_eq!(prefixes, vec![b"Hello".to_vec(), b"Helicopter".to_vec()]); - // move to the values - let values = decoder.values(); - assert_eq!(values, expected_values); Ok(()) } } diff --git a/crates/polars-parquet/src/parquet/encoding/delta_byte_array/encoder.rs b/crates/polars-parquet/src/parquet/encoding/delta_byte_array/encoder.rs index 1e9e071c87be..3a36e90b9966 100644 --- a/crates/polars-parquet/src/parquet/encoding/delta_byte_array/encoder.rs +++ b/crates/polars-parquet/src/parquet/encoding/delta_byte_array/encoder.rs @@ -2,7 +2,10 @@ use super::super::delta_bitpacked; use crate::parquet::encoding::delta_length_byte_array; /// Encodes an iterator of according to DELTA_BYTE_ARRAY -pub fn encode<'a, I: Iterator + Clone>(iterator: I, buffer: &mut Vec) { +pub fn encode<'a, I: ExactSizeIterator + Clone>( + iterator: I, + buffer: &mut Vec, +) { let mut previous = b"".as_ref(); let mut sum_lengths = 0; @@ -22,7 +25,7 @@ pub fn encode<'a, I: Iterator + Clone>(iterator: I, buffer: &mu prefix_length as i64 }) .collect::>(); - delta_bitpacked::encode(prefixes.iter().copied(), buffer); + delta_bitpacked::encode(prefixes.iter().copied(), buffer, 1); let remaining = iterator .zip(prefixes) diff --git a/crates/polars-parquet/src/parquet/encoding/delta_byte_array/mod.rs b/crates/polars-parquet/src/parquet/encoding/delta_byte_array/mod.rs index b5927ab95b58..2bb51511d67e 100644 --- a/crates/polars-parquet/src/parquet/encoding/delta_byte_array/mod.rs +++ 
b/crates/polars-parquet/src/parquet/encoding/delta_byte_array/mod.rs @@ -17,13 +17,7 @@ mod tests { let mut decoder = Decoder::try_new(&buffer)?; let prefixes = decoder.by_ref().collect::, _>>()?; - assert_eq!(prefixes, vec![0, 3]); - - // move to the lengths - let mut decoder = decoder.into_lengths()?; - - let lengths = decoder.by_ref().collect::, _>>()?; - assert_eq!(lengths, vec![5, 7]); + assert_eq!(prefixes, vec![b"Hello".to_vec(), b"Helicopter".to_vec()]); // move to the values let values = decoder.values(); diff --git a/crates/polars-parquet/src/parquet/encoding/delta_length_byte_array/decoder.rs b/crates/polars-parquet/src/parquet/encoding/delta_length_byte_array/decoder.rs index bd9a77a00add..b3191e0a51ff 100644 --- a/crates/polars-parquet/src/parquet/encoding/delta_length_byte_array/decoder.rs +++ b/crates/polars-parquet/src/parquet/encoding/delta_length_byte_array/decoder.rs @@ -1,80 +1,57 @@ use super::super::delta_bitpacked; -use crate::parquet::error::ParquetError; +use crate::parquet::encoding::delta_bitpacked::SumGatherer; +use crate::parquet::error::ParquetResult; /// Decodes [Delta-length byte array](https://github.com/apache/parquet-format/blob/master/Encodings.md#delta-length-byte-array-delta_length_byte_array--6) /// lengths and values. /// # Implementation /// This struct does not allocate on the heap. -/// # Example -/// ``` -/// use polars_parquet::parquet::encoding::delta_length_byte_array::Decoder; -/// -/// let expected = &["Hello", "World"]; -/// let expected_lengths = expected.iter().map(|x| x.len() as i32).collect::>(); -/// let expected_values = expected.join(""); -/// let expected_values = expected_values.as_bytes(); -/// let data = &[ -/// 128, 1, 4, 2, 10, 0, 0, 0, 0, 0, 72, 101, 108, 108, 111, 87, 111, 114, 108, 100, -/// ]; -/// -/// let mut decoder = Decoder::try_new(data).unwrap(); -/// -/// // Extract the lengths -/// let lengths = decoder.by_ref().collect::, _>>().unwrap(); -/// assert_eq!(lengths, expected_lengths); -/// -/// // Extract the values. This _must_ be called after consuming all lengths by reference (see above). -/// let values = decoder.into_values(); -/// -/// assert_eq!(values, expected_values); #[derive(Debug)] -pub struct Decoder<'a> { - values: &'a [u8], - lengths: delta_bitpacked::Decoder<'a>, - total_length: u32, +pub(crate) struct Decoder<'a> { + pub(crate) lengths: delta_bitpacked::Decoder<'a>, + pub(crate) values: &'a [u8], + pub(crate) offset: usize, } impl<'a> Decoder<'a> { - pub fn try_new(values: &'a [u8]) -> Result { - let lengths = delta_bitpacked::Decoder::try_new(values)?; + pub fn try_new(values: &'a [u8]) -> ParquetResult { + let (lengths, values) = delta_bitpacked::Decoder::try_new(values)?; Ok(Self { - values, lengths, - total_length: 0, + values, + offset: 0, }) } - /// Consumes this decoder and returns the slice of concatenated values. - /// # Panics - /// This function panics if this iterator has not been fully consumed. - pub fn into_values(self) -> &'a [u8] { - assert_eq!(self.lengths.size_hint().0, 0); - let start = self.lengths.consumed_bytes(); - &self.values[start..start + self.total_length as usize] + pub(crate) fn skip_in_place(&mut self, n: usize) -> ParquetResult<()> { + let mut sum = 0usize; + self.lengths + .gather_n_into(&mut sum, n, &mut SumGatherer(0))?; + self.offset += sum; + Ok(()) } - /// Returns the slice of concatenated values. - /// # Panics - /// This function panics if this iterator has not yet been fully consumed. 
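// Aside (illustration only, not part of this diff): DELTA_LENGTH_BYTE_ARRAY writes all
// value lengths first (delta-bitpacked) and then the concatenated value bytes, so
// decoding reduces to reading the lengths and slicing the value buffer at a running
// offset, which is what the rewritten `Decoder` tracks with `offset` and skips with
// `skip_in_place`. A stand-alone sketch, assuming the lengths are already decoded:
fn split_by_lengths<'a>(values: &'a [u8], lengths: &[usize]) -> Vec<&'a [u8]> {
    let mut offset = 0;
    lengths
        .iter()
        .map(|&len| {
            let out = &values[offset..offset + len];
            offset += len;
            out
        })
        .collect()
}

#[test]
fn split_by_lengths_example() {
    // Two lengths [5, 5] over "HelloWorld" yield "Hello" and "World".
    assert_eq!(
        split_by_lengths(b"HelloWorld", &[5, 5]),
        vec![b"Hello".as_ref(), b"World".as_ref()]
    );
}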
- pub fn values(&self) -> &'a [u8] { - assert_eq!(self.lengths.size_hint().0, 0); - let start = self.lengths.consumed_bytes(); - &self.values[start..start + self.total_length as usize] + pub fn len(&self) -> usize { + self.lengths.len() } } +#[cfg(test)] impl<'a> Iterator for Decoder<'a> { - type Item = Result; + type Item = ParquetResult<&'a [u8]>; fn next(&mut self) -> Option { - let result = self.lengths.next(); - match result { - Some(Ok(v)) => { - self.total_length += v as u32; - Some(Ok(v as i32)) - }, - Some(Err(error)) => Some(Err(error)), - None => None, + if self.lengths.len() == 0 { + return None; + } + + let mut length = vec![]; + if let Err(e) = self.lengths.collect_n(&mut length, 1) { + return Some(Err(e)); } + let length = length[0] as usize; + let value = &self.values[self.offset..self.offset + length]; + self.offset += length; + Some(Ok(value)) } } diff --git a/crates/polars-parquet/src/parquet/encoding/delta_length_byte_array/encoder.rs b/crates/polars-parquet/src/parquet/encoding/delta_length_byte_array/encoder.rs index fc2121cf68e8..d768b10c24f3 100644 --- a/crates/polars-parquet/src/parquet/encoding/delta_length_byte_array/encoder.rs +++ b/crates/polars-parquet/src/parquet/encoding/delta_length_byte_array/encoder.rs @@ -4,7 +4,10 @@ use crate::parquet::encoding::delta_bitpacked; /// # Implementation /// This encoding is equivalent to call [`delta_bitpacked::encode`] on the lengths of the items /// of the iterator followed by extending the buffer from each item of the iterator. -pub fn encode, I: Iterator + Clone>(iterator: I, buffer: &mut Vec) { +pub fn encode, I: ExactSizeIterator + Clone>( + iterator: I, + buffer: &mut Vec, +) { let mut total_length = 0; delta_bitpacked::encode( iterator.clone().map(|x| { @@ -13,6 +16,7 @@ pub fn encode, I: Iterator + Clone>(iterator: I, buffer len as i64 }), buffer, + 1, ); buffer.reserve(total_length); iterator.for_each(|x| buffer.extend(x.as_ref())) diff --git a/crates/polars-parquet/src/parquet/encoding/delta_length_byte_array/mod.rs b/crates/polars-parquet/src/parquet/encoding/delta_length_byte_array/mod.rs index 35b5bd9fd5fb..050ac766f545 100644 --- a/crates/polars-parquet/src/parquet/encoding/delta_length_byte_array/mod.rs +++ b/crates/polars-parquet/src/parquet/encoding/delta_length_byte_array/mod.rs @@ -1,8 +1,8 @@ mod decoder; mod encoder; -pub use decoder::Decoder; -pub use encoder::encode; +pub(crate) use decoder::Decoder; +pub(crate) use encoder::encode; #[cfg(test)] mod tests { @@ -19,9 +19,18 @@ mod tests { let mut iter = Decoder::try_new(&buffer)?; let result = iter.by_ref().collect::, _>>()?; - assert_eq!(result, vec![2, 3, 1, 2, 1]); - - let result = iter.values(); + assert_eq!( + result, + vec![ + b"aa".as_ref(), + b"bbb".as_ref(), + b"a".as_ref(), + b"aa".as_ref(), + b"b".as_ref() + ] + ); + + let result = iter.values; assert_eq!(result, b"aabbbaaab".as_ref()); Ok(()) } @@ -32,8 +41,11 @@ mod tests { for i in 0..136 { data.push(format!("a{}", i)) } - let expected_values = data.join(""); - let expected_lengths = data.iter().map(|x| x.len() as i32).collect::>(); + + let expected = data + .iter() + .map(|v| v.as_bytes().to_vec()) + .collect::>(); let mut buffer = vec![]; encode(data.into_iter(), &mut buffer); @@ -41,10 +53,8 @@ mod tests { let mut iter = Decoder::try_new(&buffer)?; let result = iter.by_ref().collect::, _>>()?; - assert_eq!(result, expected_lengths); + assert_eq!(result, expected); - let result = iter.into_values(); - assert_eq!(result, expected_values.as_bytes()); Ok(()) } } diff --git 
a/crates/polars-parquet/src/parquet/encoding/uleb128.rs b/crates/polars-parquet/src/parquet/encoding/uleb128.rs index 08459233961c..0740c9575a15 100644 --- a/crates/polars-parquet/src/parquet/encoding/uleb128.rs +++ b/crates/polars-parquet/src/parquet/encoding/uleb128.rs @@ -1,5 +1,6 @@ // Reads an uleb128 encoded integer with at most 56 bits (8 bytes with 7 bits worth of payload each). /// Returns the integer and the number of bytes that made up this integer. +/// /// If the returned length is bigger than 8 this means the integer required more than 8 bytes and the remaining bytes need to be read sequentially and combined with the return value. /// /// # Safety diff --git a/crates/polars-parquet/src/parquet/indexes/index.rs b/crates/polars-parquet/src/parquet/indexes/index.rs deleted file mode 100644 index ecf11fe7f30e..000000000000 --- a/crates/polars-parquet/src/parquet/indexes/index.rs +++ /dev/null @@ -1,322 +0,0 @@ -use std::any::Any; - -use parquet_format_safe::ColumnIndex; - -use crate::parquet::error::ParquetError; -use crate::parquet::parquet_bridge::BoundaryOrder; -use crate::parquet::schema::types::{PhysicalType, PrimitiveType}; -use crate::parquet::types::NativeType; - -/// Trait object representing a [`ColumnIndex`] in Rust's native format. -/// -/// See [`NativeIndex`], [`ByteIndex`] and [`FixedLenByteIndex`] for concrete implementations. -pub trait Index: Send + Sync + std::fmt::Debug { - fn as_any(&self) -> &dyn Any; - - fn physical_type(&self) -> &PhysicalType; -} - -impl PartialEq for dyn Index + '_ { - fn eq(&self, that: &dyn Index) -> bool { - equal(self, that) - } -} - -impl Eq for dyn Index + '_ {} - -fn equal(lhs: &dyn Index, rhs: &dyn Index) -> bool { - if lhs.physical_type() != rhs.physical_type() { - return false; - } - - match lhs.physical_type() { - PhysicalType::Boolean => { - lhs.as_any().downcast_ref::().unwrap() - == rhs.as_any().downcast_ref::().unwrap() - }, - PhysicalType::Int32 => { - lhs.as_any().downcast_ref::>().unwrap() - == rhs.as_any().downcast_ref::>().unwrap() - }, - PhysicalType::Int64 => { - lhs.as_any().downcast_ref::>().unwrap() - == rhs.as_any().downcast_ref::>().unwrap() - }, - PhysicalType::Int96 => { - lhs.as_any() - .downcast_ref::>() - .unwrap() - == rhs - .as_any() - .downcast_ref::>() - .unwrap() - }, - PhysicalType::Float => { - lhs.as_any().downcast_ref::>().unwrap() - == rhs.as_any().downcast_ref::>().unwrap() - }, - PhysicalType::Double => { - lhs.as_any().downcast_ref::>().unwrap() - == rhs.as_any().downcast_ref::>().unwrap() - }, - PhysicalType::ByteArray => { - lhs.as_any().downcast_ref::().unwrap() - == rhs.as_any().downcast_ref::().unwrap() - }, - PhysicalType::FixedLenByteArray(_) => { - lhs.as_any().downcast_ref::().unwrap() - == rhs.as_any().downcast_ref::().unwrap() - }, - } -} - -/// An index of a column of [`NativeType`] physical representation -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct NativeIndex { - /// The primitive type - pub primitive_type: PrimitiveType, - /// The indexes, one item per page - pub indexes: Vec>, - /// the order - pub boundary_order: BoundaryOrder, -} - -impl NativeIndex { - /// Creates a new [`NativeIndex`] - pub(crate) fn try_new( - index: ColumnIndex, - primitive_type: PrimitiveType, - ) -> Result { - let len = index.min_values.len(); - - let null_counts = index - .null_counts - .map(|x| x.into_iter().map(Some).collect::>()) - .unwrap_or_else(|| vec![None; len]); - - let indexes = index - .min_values - .iter() - .zip(index.max_values.into_iter()) - 
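// Aside (illustration only, not part of this diff): the uleb128.rs hunk above only fixes
// a doc comment, but for context, ULEB128 packs an unsigned integer 7 bits per byte,
// least-significant group first, with the high bit of each byte as a continuation flag.
// A minimal safe decoder, independent of the crate's specialised routine:
fn decode_uleb128(bytes: &[u8]) -> Option<(u64, usize)> {
    let mut value: u64 = 0;
    for (i, &byte) in bytes.iter().enumerate().take(10) {
        value |= u64::from(byte & 0x7F) << (7 * i);
        if byte & 0x80 == 0 {
            return Some((value, i + 1));
        }
    }
    None // input ended (or exceeded 10 bytes) before a terminating byte
}

#[test]
fn decode_uleb128_example() {
    // 624485 encodes as [0xE5, 0x8E, 0x26].
    assert_eq!(decode_uleb128(&[0xE5, 0x8E, 0x26]), Some((624485, 3)));
}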
.zip(index.null_pages.into_iter()) - .zip(null_counts.into_iter()) - .map(|(((min, max), is_null), null_count)| { - let (min, max) = if is_null { - (None, None) - } else { - let min = min.as_slice().try_into()?; - let max = max.as_slice().try_into()?; - (Some(T::from_le_bytes(min)), Some(T::from_le_bytes(max))) - }; - Ok(PageIndex { - min, - max, - null_count, - }) - }) - .collect::, ParquetError>>()?; - - Ok(Self { - primitive_type, - indexes, - boundary_order: index.boundary_order.try_into()?, - }) - } -} - -/// The index of a page, containing the min and max values of the page. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct PageIndex { - /// The minimum value in the page. It is None when all values are null - pub min: Option, - /// The maximum value in the page. It is None when all values are null - pub max: Option, - /// The number of null values in the page - pub null_count: Option, -} - -impl Index for NativeIndex { - fn as_any(&self) -> &dyn Any { - self - } - - fn physical_type(&self) -> &PhysicalType { - &T::TYPE - } -} - -/// An index of a column of bytes physical type -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct ByteIndex { - /// The [`PrimitiveType`]. - pub primitive_type: PrimitiveType, - /// The indexes, one item per page - pub indexes: Vec>>, - pub boundary_order: BoundaryOrder, -} - -impl ByteIndex { - pub(crate) fn try_new( - index: ColumnIndex, - primitive_type: PrimitiveType, - ) -> Result { - let len = index.min_values.len(); - - let null_counts = index - .null_counts - .map(|x| x.into_iter().map(Some).collect::>()) - .unwrap_or_else(|| vec![None; len]); - - let indexes = index - .min_values - .into_iter() - .zip(index.max_values.into_iter()) - .zip(index.null_pages.into_iter()) - .zip(null_counts.into_iter()) - .map(|(((min, max), is_null), null_count)| { - let (min, max) = if is_null { - (None, None) - } else { - (Some(min), Some(max)) - }; - Ok(PageIndex { - min, - max, - null_count, - }) - }) - .collect::, ParquetError>>()?; - - Ok(Self { - primitive_type, - indexes, - boundary_order: index.boundary_order.try_into()?, - }) - } -} - -impl Index for ByteIndex { - fn as_any(&self) -> &dyn Any { - self - } - - fn physical_type(&self) -> &PhysicalType { - &PhysicalType::ByteArray - } -} - -/// An index of a column of fixed len byte physical type -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct FixedLenByteIndex { - /// The [`PrimitiveType`]. 
- pub primitive_type: PrimitiveType, - /// The indexes, one item per page - pub indexes: Vec>>, - pub boundary_order: BoundaryOrder, -} - -impl FixedLenByteIndex { - pub(crate) fn try_new( - index: ColumnIndex, - primitive_type: PrimitiveType, - ) -> Result { - let len = index.min_values.len(); - - let null_counts = index - .null_counts - .map(|x| x.into_iter().map(Some).collect::>()) - .unwrap_or_else(|| vec![None; len]); - - let indexes = index - .min_values - .into_iter() - .zip(index.max_values.into_iter()) - .zip(index.null_pages.into_iter()) - .zip(null_counts.into_iter()) - .map(|(((min, max), is_null), null_count)| { - let (min, max) = if is_null { - (None, None) - } else { - (Some(min), Some(max)) - }; - Ok(PageIndex { - min, - max, - null_count, - }) - }) - .collect::, ParquetError>>()?; - - Ok(Self { - primitive_type, - indexes, - boundary_order: index.boundary_order.try_into()?, - }) - } -} - -impl Index for FixedLenByteIndex { - fn as_any(&self) -> &dyn Any { - self - } - - fn physical_type(&self) -> &PhysicalType { - &self.primitive_type.physical_type - } -} - -/// An index of a column of boolean physical type -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct BooleanIndex { - /// The indexes, one item per page - pub indexes: Vec>, - pub boundary_order: BoundaryOrder, -} - -impl BooleanIndex { - pub(crate) fn try_new(index: ColumnIndex) -> Result { - let len = index.min_values.len(); - - let null_counts = index - .null_counts - .map(|x| x.into_iter().map(Some).collect::>()) - .unwrap_or_else(|| vec![None; len]); - - let indexes = index - .min_values - .into_iter() - .zip(index.max_values.into_iter()) - .zip(index.null_pages.into_iter()) - .zip(null_counts.into_iter()) - .map(|(((min, max), is_null), null_count)| { - let (min, max) = if is_null { - (None, None) - } else { - let min = min[0] == 1; - let max = max[0] == 1; - (Some(min), Some(max)) - }; - Ok(PageIndex { - min, - max, - null_count, - }) - }) - .collect::, ParquetError>>()?; - - Ok(Self { - indexes, - boundary_order: index.boundary_order.try_into()?, - }) - } -} - -impl Index for BooleanIndex { - fn as_any(&self) -> &dyn Any { - self - } - - fn physical_type(&self) -> &PhysicalType { - &PhysicalType::Boolean - } -} diff --git a/crates/polars-parquet/src/parquet/indexes/intervals.rs b/crates/polars-parquet/src/parquet/indexes/intervals.rs deleted file mode 100644 index d04d3104a618..000000000000 --- a/crates/polars-parquet/src/parquet/indexes/intervals.rs +++ /dev/null @@ -1,139 +0,0 @@ -use parquet_format_safe::PageLocation; -#[cfg(feature = "serde_types")] -use serde::{Deserialize, Serialize}; - -use crate::parquet::error::ParquetError; - -/// An interval -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] -pub struct Interval { - /// Its start - pub start: usize, - /// Its length - pub length: usize, -} - -impl Interval { - /// Create a new interval - pub fn new(start: usize, length: usize) -> Self { - Self { start, length } - } -} - -/// Returns the set of (row) intervals of the pages. -/// # Errors -/// This function errors if the locations are not castable to `usize` or such that -/// their ranges of row are larger than `num_rows`. 
-pub fn compute_page_row_intervals( - locations: &[PageLocation], - num_rows: usize, -) -> Result, ParquetError> { - if locations.is_empty() { - return Ok(vec![]); - }; - - let last = (|| { - let start: usize = locations.last().unwrap().first_row_index.try_into()?; - let length = num_rows.checked_sub(start).ok_or_else(|| { - ParquetError::oos("Page start cannot be smaller than the number of rows") - })?; - Result::<_, ParquetError>::Ok(Interval::new(start, length)) - })(); - - let pages_lengths = locations - .windows(2) - .map(|x| { - let start = x[0].first_row_index.try_into()?; - - let length = x[1] - .first_row_index - .checked_sub(x[0].first_row_index) - .ok_or_else(|| { - ParquetError::oos("Page start cannot be smaller than the number of rows") - })? - .try_into()?; - - Ok(Interval::new(start, length)) - }) - .chain(std::iter::once(last)); - pages_lengths.collect() -} - -/// Returns the set of intervals `(start, len)` containing all the -/// selected rows (for a given column) -pub fn compute_rows( - selected: &[bool], - locations: &[PageLocation], - num_rows: usize, -) -> Result, ParquetError> { - let page_intervals = compute_page_row_intervals(locations, num_rows)?; - - Ok(selected - .iter() - .zip(page_intervals.iter().copied()) - .filter_map( - |(&is_selected, page)| { - if is_selected { - Some(page) - } else { - None - } - }, - ) - .collect()) -} - -/// An enum describing a page that was either selected in a filter pushdown or skipped -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] -pub struct FilteredPage { - /// Location of the page in the file - pub start: u64, - pub length: usize, - /// rows to select from the page - pub selected_rows: Vec, - pub num_rows: usize, -} - -fn is_in(probe: Interval, intervals: &[Interval]) -> Vec { - intervals - .iter() - .filter_map(|interval| { - let interval_end = interval.start + interval.length; - let probe_end = probe.start + probe.length; - let overlaps = (probe.start < interval_end) && (probe_end > interval.start); - if overlaps { - let start = interval.start.max(probe.start); - let end = interval_end.min(probe_end); - Some(Interval::new(start - probe.start, end - start)) - } else { - None - } - }) - .collect() -} - -/// Given a set of selected [Interval]s of rows and the set of [`PageLocation`], returns the -/// a set of [`FilteredPage`] with the same number of items as `locations`. 
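// Aside (illustration only, not part of this diff): the deleted `is_in` above clips the
// selected row intervals against a single page's row interval and re-bases the result
// to page-local coordinates. The overlap arithmetic in isolation, using plain
// `(start, length)` tuples instead of the removed `Interval` type:
fn clip_to_page(page: (usize, usize), selected: &[(usize, usize)]) -> Vec<(usize, usize)> {
    let (page_start, page_len) = page;
    let page_end = page_start + page_len;
    selected
        .iter()
        .filter_map(|&(start, len)| {
            let end = start + len;
            let overlap_start = start.max(page_start);
            let overlap_end = end.min(page_end);
            (overlap_start < overlap_end)
                .then(|| (overlap_start - page_start, overlap_end - overlap_start))
        })
        .collect()
}

#[test]
fn clip_to_page_example() {
    // Selected rows 5..11 overlap a page covering rows 10..20 in exactly one row.
    assert_eq!(clip_to_page((10, 10), &[(5, 6)]), vec![(0, 1)]);
}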
-pub fn select_pages( - intervals: &[Interval], - locations: &[PageLocation], - num_rows: usize, -) -> Result, ParquetError> { - let page_intervals = compute_page_row_intervals(locations, num_rows)?; - - page_intervals - .into_iter() - .zip(locations.iter()) - .map(|(interval, location)| { - let selected_rows = is_in(interval, intervals); - Ok(FilteredPage { - start: location.offset.try_into()?, - length: location.compressed_page_size.try_into()?, - selected_rows, - num_rows: interval.length, - }) - }) - .collect() -} diff --git a/crates/polars-parquet/src/parquet/indexes/mod.rs b/crates/polars-parquet/src/parquet/indexes/mod.rs deleted file mode 100644 index f652f8bb4be3..000000000000 --- a/crates/polars-parquet/src/parquet/indexes/mod.rs +++ /dev/null @@ -1,234 +0,0 @@ -mod index; -mod intervals; - -pub use intervals::{compute_rows, select_pages, FilteredPage, Interval}; - -pub use self::index::{BooleanIndex, ByteIndex, FixedLenByteIndex, Index, NativeIndex, PageIndex}; -pub use crate::parquet::parquet_bridge::BoundaryOrder; -pub use crate::parquet::thrift_format::PageLocation; - -#[cfg(test)] -mod tests { - use super::*; - use crate::parquet::schema::types::{PhysicalType, PrimitiveType}; - - #[test] - fn test_basic() { - let locations = &[PageLocation { - offset: 100, - compressed_page_size: 10, - first_row_index: 0, - }]; - let num_rows = 10; - - let row_intervals = compute_rows(&[true; 1], locations, num_rows).unwrap(); - assert_eq!(row_intervals, vec![Interval::new(0, 10)]) - } - - #[test] - fn test_multiple() { - // two pages - let index = ByteIndex { - primitive_type: PrimitiveType::from_physical("c1".to_string(), PhysicalType::ByteArray), - indexes: vec![ - PageIndex { - min: Some(vec![0]), - max: Some(vec![8, 9]), - null_count: Some(0), - }, - PageIndex { - min: Some(vec![20]), - max: Some(vec![98, 99]), - null_count: Some(0), - }, - ], - boundary_order: Default::default(), - }; - let locations = &[ - PageLocation { - offset: 100, - compressed_page_size: 10, - first_row_index: 0, - }, - PageLocation { - offset: 110, - compressed_page_size: 20, - first_row_index: 5, - }, - ]; - let num_rows = 10; - - // filter of the form `x > "a"` - let selector = |page: &PageIndex>| { - page.max - .as_ref() - .map(|x| x.as_slice()[0] > 97) - .unwrap_or(false) // no max is present => all nulls => not selected - }; - let selected = index.indexes.iter().map(selector).collect::>(); - - let rows = compute_rows(&selected, locations, num_rows).unwrap(); - assert_eq!(rows, vec![Interval::new(5, 5)]); - - let pages = select_pages(&rows, locations, num_rows).unwrap(); - - assert_eq!( - pages, - vec![ - FilteredPage { - start: 100, - length: 10, - selected_rows: vec![], - num_rows: 5 - }, - FilteredPage { - start: 110, - length: 20, - selected_rows: vec![Interval::new(0, 5)], - num_rows: 5 - } - ] - ); - } - - #[test] - fn test_other_column() { - let locations = &[ - PageLocation { - offset: 100, - compressed_page_size: 20, - first_row_index: 0, - }, - PageLocation { - offset: 120, - compressed_page_size: 20, - first_row_index: 10, - }, - ]; - let num_rows = 100; - - let intervals = &[Interval::new(5, 5)]; - - let pages = select_pages(intervals, locations, num_rows).unwrap(); - - assert_eq!( - pages, - vec![ - FilteredPage { - start: 100, - length: 20, - selected_rows: vec![Interval::new(5, 5)], - num_rows: 10, - }, - FilteredPage { - start: 120, - length: 20, - selected_rows: vec![], - num_rows: 90 - }, - ] - ); - } - - #[test] - fn test_other_interval_in_middle() { - let locations = &[ - PageLocation { - 
offset: 100, - compressed_page_size: 20, - first_row_index: 0, - }, - PageLocation { - offset: 120, - compressed_page_size: 20, - first_row_index: 10, - }, - PageLocation { - offset: 140, - compressed_page_size: 20, - first_row_index: 100, - }, - ]; - let num_rows = 200; - - // interval partially intersects 2 pages (0 and 1) - let intervals = &[Interval::new(5, 6)]; - - let pages = select_pages(intervals, locations, num_rows).unwrap(); - - assert_eq!( - pages, - vec![ - FilteredPage { - start: 100, - length: 20, - selected_rows: vec![Interval::new(5, 5)], - num_rows: 10, - }, - FilteredPage { - start: 120, - length: 20, - selected_rows: vec![Interval::new(0, 1)], - num_rows: 90, - }, - FilteredPage { - start: 140, - length: 20, - selected_rows: vec![], - num_rows: 100 - }, - ] - ); - } - - #[test] - fn test_other_column2() { - let locations = &[ - PageLocation { - offset: 100, - compressed_page_size: 20, - first_row_index: 0, - }, - PageLocation { - offset: 120, - compressed_page_size: 20, - first_row_index: 10, - }, - PageLocation { - offset: 140, - compressed_page_size: 20, - first_row_index: 100, - }, - ]; - let num_rows = 200; - - // interval partially intersects 1 page (0) - let intervals = &[Interval::new(0, 1)]; - - let pages = select_pages(intervals, locations, num_rows).unwrap(); - - assert_eq!( - pages, - vec![ - FilteredPage { - start: 100, - length: 20, - selected_rows: vec![Interval::new(0, 1)], - num_rows: 10, - }, - FilteredPage { - start: 120, - length: 20, - selected_rows: vec![], - num_rows: 90 - }, - FilteredPage { - start: 140, - length: 20, - selected_rows: vec![], - num_rows: 100 - }, - ] - ); - } -} diff --git a/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs b/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs index ac24bc8199ac..30a606d6108a 100644 --- a/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs +++ b/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs @@ -21,11 +21,14 @@ mod serde_types { use serde_types::*; /// Metadata for a column chunk. -// This contains the `ColumnDescriptor` associated with the chunk so that deserializers have -// access to the descriptor (e.g. physical, converted, logical). -#[derive(Debug, Clone)] +/// +/// This contains the `ColumnDescriptor` associated with the chunk so that deserializers have +/// access to the descriptor (e.g. physical, converted, logical). +/// +/// This struct is intentionally not `Clone`, as it is a huge struct. +#[derive(Debug)] #[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] -pub struct ColumnChunkMetaData { +pub struct ColumnChunkMetadata { #[cfg_attr( feature = "serde_types", serde(serialize_with = "serialize_column_chunk") @@ -67,8 +70,8 @@ where } // Represents common operations for a column chunk. 
-impl ColumnChunkMetaData { - /// Returns a new [`ColumnChunkMetaData`] +impl ColumnChunkMetadata { + /// Returns a new [`ColumnChunkMetadata`] pub fn new(column_chunk: ColumnChunk, column_descr: ColumnDescriptor) -> Self { Self { column_chunk, @@ -164,15 +167,9 @@ impl ColumnChunkMetaData { } /// Returns the offset and length in bytes of the column chunk within the file - pub fn byte_range(&self) -> (u64, u64) { - let start = if let Some(dict_page_offset) = self.dictionary_page_offset() { - dict_page_offset as u64 - } else { - self.data_page_offset() as u64 - }; - let length = self.compressed_size() as u64; + pub fn byte_range(&self) -> core::ops::Range { // this has been validated in [`try_from_thrift`] - (start, length) + column_metadata_byte_range(self.metadata()) } /// Method to convert from Thrift. @@ -205,3 +202,15 @@ impl ColumnChunkMetaData { self.column_chunk } } + +pub(super) fn column_metadata_byte_range( + column_metadata: &ColumnMetaData, +) -> core::ops::Range { + let offset = if let Some(dict_page_offset) = column_metadata.dictionary_page_offset { + dict_page_offset as u64 + } else { + column_metadata.data_page_offset as u64 + }; + let len = column_metadata.total_compressed_size as u64; + offset..offset.checked_add(len).unwrap() +} diff --git a/crates/polars-parquet/src/parquet/metadata/column_descriptor.rs b/crates/polars-parquet/src/parquet/metadata/column_descriptor.rs index 2c9a0d1f6e48..035ba32ad002 100644 --- a/crates/polars-parquet/src/parquet/metadata/column_descriptor.rs +++ b/crates/polars-parquet/src/parquet/metadata/column_descriptor.rs @@ -1,3 +1,4 @@ +use polars_utils::pl_str::PlSmallStr; #[cfg(feature = "serde_types")] use serde::{Deserialize, Serialize}; @@ -28,7 +29,7 @@ pub struct ColumnDescriptor { pub descriptor: Descriptor, /// The path of this column. For instance, "a.b.c.d". - pub path_in_schema: Vec, + pub path_in_schema: Vec, /// The [`ParquetType`] this descriptor is a leaf of pub base_type: ParquetType, @@ -38,7 +39,7 @@ impl ColumnDescriptor { /// Creates new descriptor for leaf-level column. pub fn new( descriptor: Descriptor, - path_in_schema: Vec, + path_in_schema: Vec, base_type: ParquetType, ) -> Self { Self { diff --git a/crates/polars-parquet/src/parquet/metadata/file_metadata.rs b/crates/polars-parquet/src/parquet/metadata/file_metadata.rs index 7ae449c64d90..a7ffd6f7ba6d 100644 --- a/crates/polars-parquet/src/parquet/metadata/file_metadata.rs +++ b/crates/polars-parquet/src/parquet/metadata/file_metadata.rs @@ -10,7 +10,7 @@ pub use crate::parquet::thrift_format::KeyValue; /// Metadata for a Parquet file. // This is almost equal to [`parquet_format_safe::FileMetaData`] but contains the descriptors, // which are crucial to deserialize pages. -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct FileMetaData { /// version of this file. pub version: i32, @@ -87,25 +87,6 @@ impl FileMetaData { column_orders, }) } - - /// Serializes itself to thrift's [`parquet_format_safe::FileMetaData`]. - pub fn into_thrift(self) -> parquet_format_safe::FileMetaData { - parquet_format_safe::FileMetaData { - version: self.version, - schema: self.schema_descr.into_thrift(), - num_rows: self.num_rows as i64, - row_groups: self - .row_groups - .into_iter() - .map(|v| v.into_thrift()) - .collect(), - key_value_metadata: self.key_value_metadata, - created_by: self.created_by, - column_orders: None, // todo - encryption_algorithm: None, - footer_signing_key_metadata: None, - } - } } /// Parses [`ColumnOrder`] from Thrift definition. 
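// Aside (illustration only, not part of this diff): `column_metadata_byte_range` above
// computes a column chunk's span in the file: it starts at the dictionary page offset
// when a dictionary page exists (the dictionary precedes the data pages), otherwise at
// the first data page offset, and extends by the chunk's total compressed size. The
// same computation with the three relevant thrift fields passed as plain parameters:
fn chunk_byte_range(
    dictionary_page_offset: Option<i64>,
    data_page_offset: i64,
    total_compressed_size: i64,
) -> std::ops::Range<u64> {
    let start = dictionary_page_offset.unwrap_or(data_page_offset) as u64;
    let len = total_compressed_size as u64;
    start..start.checked_add(len).unwrap()
}

#[test]
fn chunk_byte_range_example() {
    assert_eq!(chunk_byte_range(Some(100), 120, 64), 100..164);
    assert_eq!(chunk_byte_range(None, 120, 64), 120..184);
}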
diff --git a/crates/polars-parquet/src/parquet/metadata/mod.rs b/crates/polars-parquet/src/parquet/metadata/mod.rs index 2dfe81138fdd..c153cd7cf592 100644 --- a/crates/polars-parquet/src/parquet/metadata/mod.rs +++ b/crates/polars-parquet/src/parquet/metadata/mod.rs @@ -6,7 +6,7 @@ mod row_metadata; mod schema_descriptor; mod sort; -pub use column_chunk_metadata::ColumnChunkMetaData; +pub use column_chunk_metadata::ColumnChunkMetadata; pub use column_descriptor::{ColumnDescriptor, Descriptor}; pub use column_order::ColumnOrder; pub use file_metadata::{FileMetaData, KeyValue}; diff --git a/crates/polars-parquet/src/parquet/metadata/row_metadata.rs b/crates/polars-parquet/src/parquet/metadata/row_metadata.rs index 54bf1d9ac718..717bc7e243d8 100644 --- a/crates/polars-parquet/src/parquet/metadata/row_metadata.rs +++ b/crates/polars-parquet/src/parquet/metadata/row_metadata.rs @@ -1,38 +1,64 @@ +use std::sync::Arc; + +use hashbrown::hash_map::RawEntryMut; use parquet_format_safe::RowGroup; -#[cfg(feature = "serde_types")] -use serde::{Deserialize, Serialize}; +use polars_utils::aliases::{InitHashMaps, PlHashMap}; +use polars_utils::idx_vec::UnitVec; +use polars_utils::pl_str::PlSmallStr; +use polars_utils::unitvec; -use super::column_chunk_metadata::ColumnChunkMetaData; +use super::column_chunk_metadata::{column_metadata_byte_range, ColumnChunkMetadata}; use super::schema_descriptor::SchemaDescriptor; use crate::parquet::error::{ParquetError, ParquetResult}; -use crate::parquet::write::ColumnOffsetsMetadata; + +type ColumnLookup = PlHashMap>; + +trait InitColumnLookup { + fn add_column(&mut self, index: usize, column: &ColumnChunkMetadata); +} + +impl InitColumnLookup for ColumnLookup { + #[inline(always)] + fn add_column(&mut self, index: usize, column: &ColumnChunkMetadata) { + let root_name = &column.descriptor().path_in_schema[0]; + + match self.raw_entry_mut().from_key(root_name) { + RawEntryMut::Vacant(slot) => { + slot.insert(root_name.clone(), unitvec![index]); + }, + RawEntryMut::Occupied(mut slot) => { + slot.get_mut().push(index); + }, + }; + } +} /// Metadata for a row group. #[derive(Debug, Clone, Default)] -#[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] pub struct RowGroupMetaData { - columns: Vec, + columns: Arc<[ColumnChunkMetadata]>, + column_lookup: PlHashMap>, num_rows: usize, total_byte_size: usize, + full_byte_range: core::ops::Range, } impl RowGroupMetaData { - /// Create a new [`RowGroupMetaData`] - pub fn new( - columns: Vec, - num_rows: usize, - total_byte_size: usize, - ) -> RowGroupMetaData { - Self { - columns, - num_rows, - total_byte_size, - } + #[inline(always)] + pub fn n_columns(&self) -> usize { + self.columns.len() } - /// Returns slice of column chunk metadata. - pub fn columns(&self) -> &[ColumnChunkMetaData] { - &self.columns + /// Fetch all columns under this root name. + pub fn columns_under_root_iter( + &self, + root_name: &str, + ) -> impl ExactSizeIterator + DoubleEndedIterator { + self.column_lookup + .get(root_name) + .unwrap() + .iter() + .map(|&x| &self.columns[x]) } /// Number of rows in this row group. @@ -53,6 +79,14 @@ impl RowGroupMetaData { .sum::() } + pub fn full_byte_range(&self) -> core::ops::Range { + self.full_byte_range.clone() + } + + pub fn byte_ranges_iter(&self) -> impl '_ + ExactSizeIterator> { + self.columns.iter().map(|x| x.byte_range()) + } + /// Method to convert from Thrift. 
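// Aside (illustration only, not part of this diff): the new `column_lookup` above maps
// the root name of each leaf column (the first element of its `path_in_schema`) to the
// indices of the column chunks under that root, so `columns_under_root_iter` can find
// every leaf of a nested field without scanning all columns. The same grouping using
// only std types (PlHashMap, PlSmallStr and UnitVec are polars-utils specifics):
use std::collections::HashMap;

fn group_by_root(paths: &[Vec<String>]) -> HashMap<String, Vec<usize>> {
    let mut lookup: HashMap<String, Vec<usize>> = HashMap::new();
    for (index, path) in paths.iter().enumerate() {
        lookup.entry(path[0].clone()).or_default().push(index);
    }
    lookup
}

#[test]
fn group_by_root_example() {
    // A nested field `s` with two leaves, plus a flat field `a`.
    let paths = vec![
        vec!["s".to_string(), "x".to_string()],
        vec!["s".to_string(), "y".to_string()],
        vec!["a".to_string()],
    ];
    let lookup = group_by_root(&paths);
    assert_eq!(lookup["s"], vec![0, 1]);
    assert_eq!(lookup["a"], vec![2]);
}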
pub(crate) fn try_from_thrift( schema_descr: &SchemaDescriptor, @@ -63,41 +97,42 @@ impl RowGroupMetaData { } let total_byte_size = rg.total_byte_size.try_into()?; let num_rows = rg.num_rows.try_into()?; + + let mut column_lookup = ColumnLookup::with_capacity(rg.columns.len()); + let mut full_byte_range = if let Some(first_column_chunk) = rg.columns.first() { + let Some(metadata) = &first_column_chunk.meta_data else { + return Err(ParquetError::oos("Column chunk requires metadata")); + }; + column_metadata_byte_range(metadata) + } else { + 0..0 + }; + let columns = rg .columns .into_iter() .zip(schema_descr.columns()) - .map(|(column_chunk, descriptor)| { - ColumnChunkMetaData::try_from_thrift(descriptor.clone(), column_chunk) + .enumerate() + .map(|(i, (column_chunk, descriptor))| { + let column = + ColumnChunkMetadata::try_from_thrift(descriptor.clone(), column_chunk)?; + + column_lookup.add_column(i, &column); + + let byte_range = column.byte_range(); + full_byte_range = full_byte_range.start.min(byte_range.start) + ..full_byte_range.end.max(byte_range.end); + + Ok(column) }) - .collect::>>()?; + .collect::>>()?; Ok(RowGroupMetaData { columns, + column_lookup, num_rows, total_byte_size, + full_byte_range, }) } - - /// Method to convert to Thrift. - pub(crate) fn into_thrift(self) -> RowGroup { - let file_offset = self - .columns - .iter() - .map(|c| { - ColumnOffsetsMetadata::from_column_chunk_metadata(c).calc_row_group_file_offset() - }) - .next() - .unwrap_or(None); - let total_compressed_size = Some(self.compressed_size() as i64); - RowGroup { - columns: self.columns.into_iter().map(|v| v.into_thrift()).collect(), - total_byte_size: self.total_byte_size as i64, - num_rows: self.num_rows as i64, - sorting_columns: None, - file_offset, - total_compressed_size, - ordinal: None, - } - } } diff --git a/crates/polars-parquet/src/parquet/metadata/schema_descriptor.rs b/crates/polars-parquet/src/parquet/metadata/schema_descriptor.rs index 734ee054aebe..7c29f983ee1d 100644 --- a/crates/polars-parquet/src/parquet/metadata/schema_descriptor.rs +++ b/crates/polars-parquet/src/parquet/metadata/schema_descriptor.rs @@ -1,4 +1,5 @@ use parquet_format_safe::SchemaElement; +use polars_utils::pl_str::PlSmallStr; #[cfg(feature = "serde_types")] use serde::{Deserialize, Serialize}; @@ -13,7 +14,7 @@ use crate::parquet::schema::Repetition; #[derive(Debug, Clone)] #[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] pub struct SchemaDescriptor { - name: String, + name: PlSmallStr, // The top-level schema (the "message" type). fields: Vec, @@ -24,7 +25,7 @@ pub struct SchemaDescriptor { impl SchemaDescriptor { /// Creates new schema descriptor from Parquet schema. 
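// Aside (illustration only, not part of this diff): `full_byte_range` built in
// `try_from_thrift` above is simply the union of every column chunk's byte range in the
// row group, folded as the minimum start and maximum end. In isolation:
fn union_ranges(ranges: &[std::ops::Range<u64>]) -> std::ops::Range<u64> {
    ranges
        .iter()
        .cloned()
        .reduce(|acc, r| acc.start.min(r.start)..acc.end.max(r.end))
        .unwrap_or(0..0)
}

#[test]
fn union_ranges_example() {
    assert_eq!(union_ranges(&[100..164, 164..200, 90..100]), 90..200);
}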
- pub fn new(name: String, fields: Vec) -> Self { + pub fn new(name: PlSmallStr, fields: Vec) -> Self { let mut leaves = vec![]; for f in &fields { let mut path = vec![]; @@ -113,7 +114,7 @@ fn build_tree<'a>( match tp { ParquetType::PrimitiveType(p) => { - let path_in_schema = path_so_far.iter().copied().map(String::from).collect(); + let path_in_schema = path_so_far.iter().copied().map(Into::into).collect(); leaves.push(ColumnDescriptor::new( Descriptor { primitive_type: p.clone(), diff --git a/crates/polars-parquet/src/parquet/mod.rs b/crates/polars-parquet/src/parquet/mod.rs index f40b21ea0e04..ea6b5b2c8357 100644 --- a/crates/polars-parquet/src/parquet/mod.rs +++ b/crates/polars-parquet/src/parquet/mod.rs @@ -4,7 +4,6 @@ pub mod error; pub mod bloom_filter; pub mod compression; pub mod encoding; -pub mod indexes; pub mod metadata; pub mod page; mod parquet_bridge; diff --git a/crates/polars-parquet/src/parquet/page/mod.rs b/crates/polars-parquet/src/parquet/page/mod.rs index 62b3aa20163b..400bdfc4a0f7 100644 --- a/crates/polars-parquet/src/parquet/page/mod.rs +++ b/crates/polars-parquet/src/parquet/page/mod.rs @@ -2,7 +2,6 @@ use super::CowBuffer; use crate::parquet::compression::Compression; use crate::parquet::encoding::{get_length, Encoding}; use crate::parquet::error::{ParquetError, ParquetResult}; -use crate::parquet::indexes::Interval; use crate::parquet::metadata::Descriptor; pub use crate::parquet::parquet_bridge::{DataPageHeaderExt, PageType}; use crate::parquet::statistics::Statistics; @@ -24,9 +23,7 @@ pub struct CompressedDataPage { pub(crate) compression: Compression, uncompressed_page_size: usize, pub(crate) descriptor: Descriptor, - - // The offset and length in rows - pub(crate) selected_rows: Option>, + pub num_rows: Option, } impl CompressedDataPage { @@ -37,16 +34,16 @@ impl CompressedDataPage { compression: Compression, uncompressed_page_size: usize, descriptor: Descriptor, - rows: Option, + num_rows: usize, ) -> Self { - Self::new_read( + Self { header, buffer, compression, uncompressed_page_size, descriptor, - rows.map(|x| vec![Interval::new(0, x)]), - ) + num_rows: Some(num_rows), + } } /// Returns a new [`CompressedDataPage`]. @@ -56,7 +53,6 @@ impl CompressedDataPage { compression: Compression, uncompressed_page_size: usize, descriptor: Descriptor, - selected_rows: Option>, ) -> Self { Self { header, @@ -64,7 +60,7 @@ impl CompressedDataPage { compression, uncompressed_page_size, descriptor, - selected_rows, + num_rows: None, } } @@ -87,16 +83,14 @@ impl CompressedDataPage { self.compression } - /// the rows to be selected by this page. - /// When `None`, all rows are to be considered. 
- pub fn selected_rows(&self) -> Option<&[Interval]> { - self.selected_rows.as_deref() - } - pub fn num_values(&self) -> usize { self.header.num_values() } + pub fn num_rows(&self) -> Option { + self.num_rows + } + /// Decodes the raw statistics into a statistics pub fn statistics(&self) -> Option> { match &self.header { @@ -111,11 +105,6 @@ impl CompressedDataPage { } } - #[inline] - pub fn select_rows(&mut self, selected_rows: Vec) { - self.selected_rows = Some(selected_rows); - } - pub fn slice_mut(&mut self) -> &mut CowBuffer { &mut self.buffer } @@ -134,6 +123,13 @@ impl DataPageHeader { DataPageHeader::V2(d) => d.num_values as usize, } } + + pub fn null_count(&self) -> Option { + match &self { + DataPageHeader::V1(_) => None, + DataPageHeader::V2(d) => Some(d.num_nulls as usize), + } + } } /// A [`DataPage`] is an uncompressed, encoded representation of a Parquet data page. It holds actual data @@ -143,7 +139,7 @@ pub struct DataPage { pub(super) header: DataPageHeader, pub(super) buffer: CowBuffer, pub descriptor: Descriptor, - pub selected_rows: Option>, + pub num_rows: Option, } impl DataPage { @@ -151,27 +147,26 @@ impl DataPage { header: DataPageHeader, buffer: CowBuffer, descriptor: Descriptor, - rows: Option, + num_rows: usize, ) -> Self { - Self::new_read( + Self { header, buffer, descriptor, - rows.map(|x| vec![Interval::new(0, x)]), - ) + num_rows: Some(num_rows), + } } pub(crate) fn new_read( header: DataPageHeader, buffer: CowBuffer, descriptor: Descriptor, - selected_rows: Option>, ) -> Self { Self { header, buffer, descriptor, - selected_rows, + num_rows: None, } } @@ -183,12 +178,6 @@ impl DataPage { &self.buffer } - /// the rows to be selected by this page. - /// When `None`, all rows are to be considered. - pub fn selected_rows(&self) -> Option<&[Interval]> { - self.selected_rows.as_deref() - } - /// Returns a mutable reference to the internal buffer. /// Useful to recover the buffer after the page has been decoded. 
pub fn buffer_mut(&mut self) -> &mut Vec { @@ -199,6 +188,14 @@ impl DataPage { self.header.num_values() } + pub fn null_count(&self) -> Option { + self.header.null_count() + } + + pub fn num_rows(&self) -> Option { + self.num_rows + } + pub fn encoding(&self) -> Encoding { match &self.header { DataPageHeader::V1(d) => d.encoding(), @@ -272,13 +269,6 @@ pub enum CompressedPage { } impl CompressedPage { - pub(crate) fn buffer(&self) -> &[u8] { - match self { - CompressedPage::Data(page) => &page.buffer, - CompressedPage::Dict(page) => &page.buffer, - } - } - pub(crate) fn buffer_mut(&mut self) -> &mut Vec { match self { CompressedPage::Data(page) => page.buffer.to_mut(), @@ -300,17 +290,10 @@ impl CompressedPage { } } - pub(crate) fn selected_rows(&self) -> Option<&[Interval]> { - match self { - CompressedPage::Data(page) => page.selected_rows(), - CompressedPage::Dict(_) => None, - } - } - - pub(crate) fn uncompressed_size(&self) -> usize { + pub(crate) fn num_rows(&self) -> Option { match self { - CompressedPage::Data(page) => page.uncompressed_page_size, - CompressedPage::Dict(page) => page.uncompressed_page_size, + CompressedPage::Data(page) => page.num_rows(), + CompressedPage::Dict(_) => Some(0), } } } diff --git a/crates/polars-parquet/src/parquet/read/column/mod.rs b/crates/polars-parquet/src/parquet/read/column/mod.rs index 2cd15c4f61e6..54065389328e 100644 --- a/crates/polars-parquet/src/parquet/read/column/mod.rs +++ b/crates/polars-parquet/src/parquet/read/column/mod.rs @@ -1,16 +1,13 @@ -use std::io::{Read, Seek}; use std::vec::IntoIter; -use super::{get_field_columns, get_page_iterator, MemReader, PageFilter, PageReader}; +use polars_utils::idx_vec::UnitVec; + +use super::{get_page_iterator, MemReader, PageReader}; use crate::parquet::error::{ParquetError, ParquetResult}; -use crate::parquet::metadata::{ColumnChunkMetaData, RowGroupMetaData}; +use crate::parquet::metadata::{ColumnChunkMetadata, RowGroupMetaData}; use crate::parquet::page::CompressedPage; use crate::parquet::schema::types::ParquetType; -#[cfg(feature = "async")] -#[cfg_attr(docsrs, doc(cfg(feature = "async")))] -mod stream; - /// Returns a [`ColumnIterator`] of column chunks corresponding to `field`. /// /// Contrarily to [`get_page_iterator`] that returns a single iterator of pages, this iterator @@ -18,18 +15,17 @@ mod stream; /// For primitive fields (e.g. `i64`), [`ColumnIterator`] yields exactly one column. /// For complex fields, it yields multiple columns. /// `max_page_size` is the maximum number of bytes allowed. -pub fn get_column_iterator( +pub fn get_column_iterator<'a>( reader: MemReader, - row_group: &RowGroupMetaData, + row_group: &'a RowGroupMetaData, field_name: &str, - page_filter: Option, max_page_size: usize, -) -> ColumnIterator { - let columns = get_field_columns(row_group.columns(), field_name) - .cloned() - .collect::>(); - - ColumnIterator::new(reader, columns, page_filter, max_page_size) +) -> ColumnIterator<'a> { + let columns = row_group + .columns_under_root_iter(field_name) + .rev() + .collect::>(); + ColumnIterator::new(reader, columns, max_page_size) } /// State of [`MutStreamingIterator`]. @@ -52,34 +48,30 @@ pub trait MutStreamingIterator: Sized { /// A [`MutStreamingIterator`] that reads column chunks one by one, /// returning a [`PageReader`] per column. 
-pub struct ColumnIterator { +pub struct ColumnIterator<'a> { reader: MemReader, - columns: Vec, - page_filter: Option, + columns: UnitVec<&'a ColumnChunkMetadata>, max_page_size: usize, } -impl ColumnIterator { +impl<'a> ColumnIterator<'a> { /// Returns a new [`ColumnIterator`] /// `max_page_size` is the maximum allowed page size pub fn new( reader: MemReader, - mut columns: Vec, - page_filter: Option, + columns: UnitVec<&'a ColumnChunkMetadata>, max_page_size: usize, ) -> Self { - columns.reverse(); Self { reader, columns, - page_filter, max_page_size, } } } -impl Iterator for ColumnIterator { - type Item = ParquetResult<(PageReader, ColumnChunkMetaData)>; +impl<'a> Iterator for ColumnIterator<'a> { + type Item = ParquetResult<(PageReader, &'a ColumnChunkMetadata)>; fn next(&mut self) -> Option { if self.columns.is_empty() { @@ -87,16 +79,11 @@ impl Iterator for ColumnIterator { }; let column = self.columns.pop().unwrap(); - let iter = match get_page_iterator( - &column, - self.reader.clone(), - self.page_filter.clone(), - Vec::new(), - self.max_page_size, - ) { - Err(e) => return Some(Err(e)), - Ok(v) => v, - }; + let iter = + match get_page_iterator(column, self.reader.clone(), Vec::new(), self.max_page_size) { + Err(e) => return Some(Err(e)), + Ok(v) => v, + }; Some(Ok((iter, column))) } } @@ -107,11 +94,11 @@ pub struct ReadColumnIterator { field: ParquetType, chunks: Vec<( Vec>, - ColumnChunkMetaData, + ColumnChunkMetadata, )>, current: Option<( IntoIter>, - ColumnChunkMetaData, + ColumnChunkMetadata, )>, } @@ -121,7 +108,7 @@ impl ReadColumnIterator { field: ParquetType, chunks: Vec<( Vec>, - ColumnChunkMetaData, + ColumnChunkMetadata, )>, ) -> Self { Self { @@ -135,7 +122,7 @@ impl ReadColumnIterator { impl MutStreamingIterator for ReadColumnIterator { type Item = ( IntoIter>, - ColumnChunkMetaData, + ColumnChunkMetadata, ); type Error = ParquetError; @@ -158,38 +145,3 @@ impl MutStreamingIterator for ReadColumnIterator { self.current.as_mut() } } - -/// Reads all columns that are part of the parquet field `field_name` -/// # Implementation -/// This operation is IO-bounded `O(C)` where C is the number of columns associated to -/// the field (one for non-nested types) -/// It reads the columns sequentially. Use [`read_column`] to fork this operation to multiple -/// readers. -pub fn read_columns<'a, R: Read + Seek>( - reader: &mut R, - columns: &'a [ColumnChunkMetaData], - field_name: &'a str, -) -> Result)>, ParquetError> { - get_field_columns(columns, field_name) - .map(|column| read_column(reader, column).map(|c| (column, c))) - .collect() -} - -/// Reads a column chunk into memory -/// This operation is IO-bounded and allocates the column's `compressed_size`. 
-pub fn read_column(reader: &mut R, column: &ColumnChunkMetaData) -> Result, ParquetError> -where - R: Read + Seek, -{ - let (start, length) = column.byte_range(); - reader.seek(std::io::SeekFrom::Start(start))?; - - let mut chunk = vec![]; - chunk.try_reserve(length as usize)?; - reader.by_ref().take(length).read_to_end(&mut chunk)?; - Ok(chunk) -} - -#[cfg(feature = "async")] -#[cfg_attr(docsrs, doc(cfg(feature = "async")))] -pub use stream::{read_column_async, read_columns_async}; diff --git a/crates/polars-parquet/src/parquet/read/column/stream.rs b/crates/polars-parquet/src/parquet/read/column/stream.rs deleted file mode 100644 index 63319d2260c6..000000000000 --- a/crates/polars-parquet/src/parquet/read/column/stream.rs +++ /dev/null @@ -1,51 +0,0 @@ -use futures::future::{try_join_all, BoxFuture}; -use futures::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt}; - -use crate::parquet::error::ParquetError; -use crate::parquet::metadata::ColumnChunkMetaData; -use crate::parquet::read::get_field_columns; - -/// Reads a single column chunk into memory asynchronously -pub async fn read_column_async<'b, R, F>( - factory: F, - meta: &ColumnChunkMetaData, -) -> Result, ParquetError> -where - R: AsyncRead + AsyncSeek + Send + Unpin, - F: Fn() -> BoxFuture<'b, std::io::Result>, -{ - let mut reader = factory().await?; - let (start, length) = meta.byte_range(); - reader.seek(std::io::SeekFrom::Start(start)).await?; - - let mut chunk = vec![]; - chunk.try_reserve(length as usize)?; - reader.take(length).read_to_end(&mut chunk).await?; - Result::Ok(chunk) -} - -/// Reads all columns that are part of the parquet field `field_name` -/// # Implementation -/// This operation is IO-bounded `O(C)` where C is the number of columns associated to -/// the field (one for non-nested types) -/// -/// It does so asynchronously via a single `join_all` over all the necessary columns for -/// `field_name`. -pub async fn read_columns_async< - 'a, - 'b, - R: AsyncRead + AsyncSeek + Send + Unpin, - F: Fn() -> BoxFuture<'b, std::io::Result> + Clone, ->( - factory: F, - columns: &'a [ColumnChunkMetaData], - field_name: &'a str, -) -> Result)>, ParquetError> { - let fields = get_field_columns(columns, field_name).collect::>(); - let futures = fields - .iter() - .map(|meta| async { read_column_async(factory.clone(), meta).await }); - - let columns = try_join_all(futures).await?; - Ok(fields.into_iter().zip(columns).collect()) -} diff --git a/crates/polars-parquet/src/parquet/read/compression.rs b/crates/polars-parquet/src/parquet/read/compression.rs index 91ff7a519c61..a79989c39e26 100644 --- a/crates/polars-parquet/src/parquet/read/compression.rs +++ b/crates/polars-parquet/src/parquet/read/compression.rs @@ -3,7 +3,9 @@ use parquet_format_safe::DataPageHeaderV2; use super::PageReader; use crate::parquet::compression::{self, Compression}; use crate::parquet::error::{ParquetError, ParquetResult}; -use crate::parquet::page::{CompressedPage, DataPage, DataPageHeader, DictPage, Page}; +use crate::parquet::page::{ + CompressedDataPage, CompressedPage, DataPage, DataPageHeader, DictPage, Page, +}; use crate::parquet::CowBuffer; fn decompress_v1( @@ -52,76 +54,72 @@ fn decompress_v2( Ok(()) } -/// decompresses a [`CompressedDataPage`] into `buffer`. -/// If the page is un-compressed, `buffer` is swapped instead. -/// Returns whether the page was decompressed. 
-pub fn decompress_buffer( - compressed_page: &mut CompressedPage, - buffer: &mut Vec, -) -> ParquetResult { - if compressed_page.compression() != Compression::Uncompressed { - // prepare the compression buffer - let read_size = compressed_page.uncompressed_size(); - - if read_size > buffer.capacity() { - // dealloc and ignore region, replacing it by a new region. - // This won't reallocate - it frees and calls `alloc_zeroed` - *buffer = vec![0; read_size]; - } else if read_size > buffer.len() { - // fill what we need with zeros so that we can use them in `Read`. - // This won't reallocate - buffer.resize(read_size, 0); - } else { - buffer.truncate(read_size); - } - match compressed_page { - CompressedPage::Data(compressed_page) => match compressed_page.header() { - DataPageHeader::V1(_) => { - decompress_v1(&compressed_page.buffer, compressed_page.compression, buffer)? - }, - DataPageHeader::V2(header) => decompress_v2( - &compressed_page.buffer, - header, - compressed_page.compression, - buffer, - )?, - }, - CompressedPage::Dict(page) => decompress_v1(&page.buffer, page.compression(), buffer)?, - } - Ok(true) - } else { - // page.buffer is already decompressed => swap it with `buffer`, making `page.buffer` the - // decompression buffer and `buffer` the decompressed buffer - std::mem::swap(&mut compressed_page.buffer().to_vec(), buffer); - Ok(false) - } -} - -fn create_page(compressed_page: CompressedPage, buffer: Vec) -> Page { - match compressed_page { - CompressedPage::Data(page) => Page::Data(DataPage::new_read( +/// Decompresses the page, using `buffer` for decompression. +/// If `page.buffer.len() == 0`, there was no decompression and the buffer was moved. +/// Else, decompression took place. +pub fn decompress(compressed_page: CompressedPage, buffer: &mut Vec) -> ParquetResult { + Ok(match (compressed_page.compression(), compressed_page) { + (Compression::Uncompressed, CompressedPage::Data(page)) => Page::Data(DataPage::new_read( page.header, - CowBuffer::Owned(buffer), + page.buffer, page.descriptor, - page.selected_rows, )), - CompressedPage::Dict(page) => Page::Dict(DictPage { - buffer: CowBuffer::Owned(buffer), + (_, CompressedPage::Data(page)) => { + // prepare the compression buffer + let read_size = page.uncompressed_size(); + + if read_size > buffer.capacity() { + // dealloc and ignore region, replacing it by a new region. + // This won't reallocate - it frees and calls `alloc_zeroed` + *buffer = vec![0; read_size]; + } else if read_size > buffer.len() { + // fill what we need with zeros so that we can use them in `Read`. + // This won't reallocate + buffer.resize(read_size, 0); + } else { + buffer.truncate(read_size); + } + + match page.header() { + DataPageHeader::V1(_) => decompress_v1(&page.buffer, page.compression, buffer)?, + DataPageHeader::V2(header) => { + decompress_v2(&page.buffer, header, page.compression, buffer)? + }, + } + let buffer = CowBuffer::Owned(std::mem::take(buffer)); + + Page::Data(DataPage::new_read(page.header, buffer, page.descriptor)) + }, + (Compression::Uncompressed, CompressedPage::Dict(page)) => Page::Dict(DictPage { + buffer: page.buffer, num_values: page.num_values, is_sorted: page.is_sorted, }), - } -} - -/// Decompresses the page, using `buffer` for decompression. -/// If `page.buffer.len() == 0`, there was no decompression and the buffer was moved. -/// Else, decompression took place. 
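// Aside (illustration only, not part of this diff): the new `decompress` above (and its
// dictionary-page branch below) sizes the caller-provided scratch buffer to the page's
// uncompressed size before decompressing into it: reallocate only when the required size
// exceeds the current capacity, zero-extend when it only exceeds the length, and
// truncate otherwise, so one allocation is reused across pages. The sizing step in
// isolation:
fn prepare_scratch(buffer: &mut Vec<u8>, read_size: usize) {
    if read_size > buffer.capacity() {
        // Free the old region and allocate a zeroed one of exactly the required size.
        *buffer = vec![0; read_size];
    } else if read_size > buffer.len() {
        // Capacity suffices; only extend the initialized length with zeros.
        buffer.resize(read_size, 0);
    } else {
        // Shrink the visible length; the capacity (and allocation) is kept.
        buffer.truncate(read_size);
    }
}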
-pub fn decompress( - mut compressed_page: CompressedPage, - buffer: &mut Vec, -) -> ParquetResult { - decompress_buffer(&mut compressed_page, buffer)?; - Ok(create_page(compressed_page, std::mem::take(buffer))) + (_, CompressedPage::Dict(page)) => { + // prepare the compression buffer + let read_size = page.uncompressed_page_size; + + if read_size > buffer.capacity() { + // dealloc and ignore region, replacing it by a new region. + // This won't reallocate - it frees and calls `alloc_zeroed` + *buffer = vec![0; read_size]; + } else if read_size > buffer.len() { + // fill what we need with zeros so that we can use them in `Read`. + // This won't reallocate + buffer.resize(read_size, 0); + } else { + buffer.truncate(read_size); + } + decompress_v1(&page.buffer, page.compression(), buffer)?; + let buffer = CowBuffer::Owned(std::mem::take(buffer)); + + Page::Dict(DictPage { + buffer, + num_values: page.num_values, + is_sorted: page.is_sorted, + }) + }, + }) } type _Decompressor = streaming_decompression::Decompressor< @@ -205,8 +203,27 @@ impl BasicDecompressor { } } +pub struct DataPageItem { + page: CompressedDataPage, +} + +impl DataPageItem { + pub fn num_values(&self) -> usize { + self.page.num_values() + } + + pub fn decompress(self, decompressor: &mut BasicDecompressor) -> ParquetResult { + let p = decompress(CompressedPage::Data(self.page), &mut decompressor.buffer)?; + let Page::Data(p) = p else { + panic!("Decompressing a data page should result in a data page"); + }; + + Ok(p) + } +} + impl Iterator for BasicDecompressor { - type Item = ParquetResult; + type Item = ParquetResult; fn next(&mut self) -> Option { let page = match self.reader.next() { @@ -215,13 +232,13 @@ impl Iterator for BasicDecompressor { Some(Ok(p)) => p, }; - Some(decompress(page, &mut self.buffer).map(|p| { - if let Page::Data(p) = p { - p - } else { - panic!("Found compressed page in the middle of the pages") - } - })) + let CompressedPage::Data(page) = page else { + return Some(Err(ParquetError::oos( + "Found dictionary page beyond the first page of a column chunk", + ))); + }; + + Some(Ok(DataPageItem { page })) } fn size_hint(&self) -> (usize, Option) { diff --git a/crates/polars-parquet/src/parquet/read/indexes/deserialize.rs b/crates/polars-parquet/src/parquet/read/indexes/deserialize.rs deleted file mode 100644 index d6bfb4de8a06..000000000000 --- a/crates/polars-parquet/src/parquet/read/indexes/deserialize.rs +++ /dev/null @@ -1,30 +0,0 @@ -use parquet_format_safe::thrift::protocol::TCompactInputProtocol; -use parquet_format_safe::ColumnIndex; - -use crate::parquet::error::ParquetError; -use crate::parquet::indexes::{BooleanIndex, ByteIndex, FixedLenByteIndex, Index, NativeIndex}; -use crate::parquet::schema::types::{PhysicalType, PrimitiveType}; - -pub fn deserialize( - data: &[u8], - primitive_type: PrimitiveType, -) -> Result, ParquetError> { - let mut prot = TCompactInputProtocol::new(data, data.len() * 2 + 1024); - - let index = ColumnIndex::read_from_in_protocol(&mut prot)?; - - let index = match primitive_type.physical_type { - PhysicalType::Boolean => Box::new(BooleanIndex::try_new(index)?) 
as Box, - PhysicalType::Int32 => Box::new(NativeIndex::::try_new(index, primitive_type)?), - PhysicalType::Int64 => Box::new(NativeIndex::::try_new(index, primitive_type)?), - PhysicalType::Int96 => Box::new(NativeIndex::<[u32; 3]>::try_new(index, primitive_type)?), - PhysicalType::Float => Box::new(NativeIndex::::try_new(index, primitive_type)?), - PhysicalType::Double => Box::new(NativeIndex::::try_new(index, primitive_type)?), - PhysicalType::ByteArray => Box::new(ByteIndex::try_new(index, primitive_type)?), - PhysicalType::FixedLenByteArray(_) => { - Box::new(FixedLenByteIndex::try_new(index, primitive_type)?) - }, - }; - - Ok(index) -} diff --git a/crates/polars-parquet/src/parquet/read/indexes/mod.rs b/crates/polars-parquet/src/parquet/read/indexes/mod.rs deleted file mode 100644 index 1e1919c84c75..000000000000 --- a/crates/polars-parquet/src/parquet/read/indexes/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -mod deserialize; -mod read; - -pub use read::*; diff --git a/crates/polars-parquet/src/parquet/read/indexes/read.rs b/crates/polars-parquet/src/parquet/read/indexes/read.rs deleted file mode 100644 index 1dbb5aa20fde..000000000000 --- a/crates/polars-parquet/src/parquet/read/indexes/read.rs +++ /dev/null @@ -1,134 +0,0 @@ -use std::io::{Cursor, Read, Seek, SeekFrom}; - -use parquet_format_safe::thrift::protocol::TCompactInputProtocol; -use parquet_format_safe::{ColumnChunk, OffsetIndex, PageLocation}; - -use super::deserialize::deserialize; -use crate::parquet::error::ParquetError; -use crate::parquet::indexes::Index; -use crate::parquet::metadata::ColumnChunkMetaData; - -fn prepare_read Option, G: Fn(&ColumnChunk) -> Option>( - chunks: &[ColumnChunkMetaData], - get_offset: F, - get_length: G, -) -> Result<(u64, Vec), ParquetError> { - // c1: [start, length] - // ... - // cN: [start, length] - - let first_chunk = if let Some(chunk) = chunks.first() { - chunk - } else { - return Ok((0, vec![])); - }; - let metadata = first_chunk.column_chunk(); - - let offset: u64 = if let Some(offset) = get_offset(metadata) { - offset.try_into()? - } else { - return Ok((0, vec![])); - }; - - let lengths = chunks - .iter() - .map(|x| get_length(x.column_chunk())) - .map(|maybe_length| { - let index_length = maybe_length.ok_or_else(|| { - ParquetError::oos("The column length must exist if column offset exists") - })?; - - Ok(index_length.try_into()?) - }) - .collect::, ParquetError>>()?; - - Ok((offset, lengths)) -} - -fn prepare_column_index_read( - chunks: &[ColumnChunkMetaData], -) -> Result<(u64, Vec), ParquetError> { - prepare_read(chunks, |x| x.column_index_offset, |x| x.column_index_length) -} - -fn prepare_offset_index_read( - chunks: &[ColumnChunkMetaData], -) -> Result<(u64, Vec), ParquetError> { - prepare_read(chunks, |x| x.offset_index_offset, |x| x.offset_index_length) -} - -fn deserialize_column_indexes( - chunks: &[ColumnChunkMetaData], - data: &[u8], - lengths: Vec, -) -> Result>, ParquetError> { - let mut start = 0; - let data = lengths.into_iter().map(|length| { - let r = &data[start..start + length]; - start += length; - r - }); - - chunks - .iter() - .zip(data) - .map(|(chunk, data)| { - let primitive_type = chunk.descriptor().descriptor.primitive_type.clone(); - deserialize(data, primitive_type) - }) - .collect() -} - -/// Reads the column indexes of all [`ColumnChunkMetaData`] and deserializes them into [`Index`]. 
-/// Returns an empty vector if indexes are not available -pub fn read_columns_indexes( - reader: &mut R, - chunks: &[ColumnChunkMetaData], -) -> Result>, ParquetError> { - let (offset, lengths) = prepare_column_index_read(chunks)?; - - let length = lengths.iter().sum::(); - - reader.seek(SeekFrom::Start(offset))?; - - let mut data = vec![]; - data.try_reserve(length)?; - reader.by_ref().take(length as u64).read_to_end(&mut data)?; - - deserialize_column_indexes(chunks, &data, lengths) -} - -fn deserialize_page_locations( - data: &[u8], - column_number: usize, -) -> Result>, ParquetError> { - let len = data.len() * 2 + 1024; - let mut reader = Cursor::new(data); - - (0..column_number) - .map(|_| { - let mut prot = TCompactInputProtocol::new(&mut reader, len); - let offset = OffsetIndex::read_from_in_protocol(&mut prot)?; - Ok(offset.page_locations) - }) - .collect() -} - -/// Read [`PageLocation`]s from the [`ColumnChunkMetaData`]s. -/// Returns an empty vector if indexes are not available -pub fn read_pages_locations( - reader: &mut R, - chunks: &[ColumnChunkMetaData], -) -> Result>, ParquetError> { - let (offset, lengths) = prepare_offset_index_read(chunks)?; - - let length = lengths.iter().sum::(); - - reader.seek(SeekFrom::Start(offset))?; - - let mut data = vec![]; - data.try_reserve(length)?; - reader.by_ref().take(length as u64).read_to_end(&mut data)?; - - deserialize_page_locations(&data, chunks.len()) -} diff --git a/crates/polars-parquet/src/parquet/read/mod.rs b/crates/polars-parquet/src/parquet/read/mod.rs index cea8561193ef..c3ec112e6864 100644 --- a/crates/polars-parquet/src/parquet/read/mod.rs +++ b/crates/polars-parquet/src/parquet/read/mod.rs @@ -1,6 +1,5 @@ mod column; mod compression; -mod indexes; pub mod levels; mod metadata; mod page; @@ -8,67 +7,33 @@ mod page; mod stream; use std::io::{Seek, SeekFrom}; -use std::sync::Arc; pub use column::*; pub use compression::{decompress, BasicDecompressor}; -pub use indexes::{read_columns_indexes, read_pages_locations}; pub use metadata::{deserialize_metadata, read_metadata, read_metadata_with_size}; #[cfg(feature = "async")] pub use page::{get_page_stream, get_page_stream_from_column_start}; -pub use page::{PageFilter, PageIterator, PageMetaData, PageReader}; +pub use page::{PageIterator, PageMetaData, PageReader}; use polars_utils::mmap::MemReader; #[cfg(feature = "async")] pub use stream::read_metadata as read_metadata_async; use crate::parquet::error::ParquetResult; -use crate::parquet::metadata::{ColumnChunkMetaData, FileMetaData, RowGroupMetaData}; - -/// Filters row group metadata to only those row groups, -/// for which the predicate function returns true -pub fn filter_row_groups( - metadata: &FileMetaData, - predicate: &dyn Fn(&RowGroupMetaData, usize) -> bool, -) -> FileMetaData { - let mut filtered_row_groups = Vec::::new(); - for (i, row_group_metadata) in metadata.row_groups.iter().enumerate() { - if predicate(row_group_metadata, i) { - filtered_row_groups.push(row_group_metadata.clone()); - } - } - let mut metadata = metadata.clone(); - metadata.row_groups = filtered_row_groups; - metadata -} +use crate::parquet::metadata::ColumnChunkMetadata; /// Returns a new [`PageReader`] by seeking `reader` to the beginning of `column_chunk`. 
pub fn get_page_iterator( - column_chunk: &ColumnChunkMetaData, + column_chunk: &ColumnChunkMetadata, mut reader: MemReader, - pages_filter: Option, scratch: Vec, max_page_size: usize, ) -> ParquetResult { - let pages_filter = pages_filter.unwrap_or_else(|| Arc::new(|_, _| true)); - - let (col_start, _) = column_chunk.byte_range(); + let col_start = column_chunk.byte_range().start; reader.seek(SeekFrom::Start(col_start))?; Ok(PageReader::new( reader, column_chunk, - pages_filter, scratch, max_page_size, )) } - -/// Returns all [`ColumnChunkMetaData`] associated to `field_name`. -/// For non-nested types, this returns an iterator with a single column -pub fn get_field_columns<'a>( - columns: &'a [ColumnChunkMetaData], - field_name: &'a str, -) -> impl Iterator { - columns - .iter() - .filter(move |x| x.descriptor().path_in_schema[0] == field_name) -} diff --git a/crates/polars-parquet/src/parquet/read/page/indexed_reader.rs b/crates/polars-parquet/src/parquet/read/page/indexed_reader.rs deleted file mode 100644 index 90788d0a7320..000000000000 --- a/crates/polars-parquet/src/parquet/read/page/indexed_reader.rs +++ /dev/null @@ -1,189 +0,0 @@ -use std::collections::VecDeque; -use std::io::{Seek, SeekFrom}; - -use polars_utils::mmap::{MemReader, MemSlice}; - -use super::reader::{finish_page, read_page_header, PageMetaData}; -use crate::parquet::error::ParquetError; -use crate::parquet::indexes::{FilteredPage, Interval}; -use crate::parquet::metadata::{ColumnChunkMetaData, Descriptor}; -use crate::parquet::page::{CompressedDictPage, CompressedPage, ParquetPageHeader}; -use crate::parquet::parquet_bridge::Compression; - -#[derive(Debug, Clone, Copy)] -enum State { - MaybeDict, - Data, -} - -/// A fallible [`Iterator`] of [`CompressedPage`]. This iterator leverages page indexes -/// to skip pages that are not needed. Consequently, the pages from this -/// iterator always have [`Some`] [`crate::parquet::page::CompressedDataPage::selected_rows()`] -pub struct IndexedPageReader { - // The source - reader: MemReader, - - column_start: u64, - compression: Compression, - - // used to deserialize dictionary pages and attach the descriptor to every read page - descriptor: Descriptor, - - // buffer to read the whole page [header][data] into memory - buffer: Vec, - - // buffer to store the data [data] and reuse across pages - data_buffer: Vec, - - pages: VecDeque, - - state: State, -} - -fn read_page( - reader: &mut MemReader, - start: u64, - length: usize, -) -> Result<(ParquetPageHeader, MemSlice), ParquetError> { - // seek to the page - reader.seek(SeekFrom::Start(start))?; - - let start_position = reader.position(); - - // deserialize [header] - let page_header = read_page_header(reader, 1024 * 1024)?; - let header_size = reader.position() - start_position; - - // copy [data] - let data = reader.read_slice(length - header_size); - - Ok((page_header, data)) -} - -fn read_dict_page( - reader: &mut MemReader, - start: u64, - length: usize, - compression: Compression, - descriptor: &Descriptor, -) -> Result { - let (page_header, data) = read_page(reader, start, length)?; - - let page = finish_page(page_header, data, compression, descriptor, None)?; - if let CompressedPage::Dict(page) = page { - Ok(page) - } else { - Err(ParquetError::oos( - "The first page is not a dictionary page but it should", - )) - } -} - -impl IndexedPageReader { - /// Returns a new [`IndexedPageReader`]. 
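With `PageFilter` gone, `get_page_iterator` above only needs the column-chunk metadata, a reader, a scratch buffer and a maximum page size. A hedged usage sketch of the new call shape (the `md` and `mem` bindings are assumed context):

    // Returns a PageReader positioned at the start of the column chunk.
    let pages = get_page_iterator(
        md,           // &ColumnChunkMetadata (renamed from ColumnChunkMetaData)
        mem,          // MemReader; it is seeked to byte_range().start internally
        Vec::new(),   // reusable scratch buffer
        1024 * 1024,  // guard against over-allocating on a corrupt page header
    )?;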
- pub fn new( - reader: MemReader, - column: &ColumnChunkMetaData, - pages: Vec, - buffer: Vec, - data_buffer: Vec, - ) -> Self { - Self::new_with_page_meta(reader, column.into(), pages, buffer, data_buffer) - } - - /// Returns a new [`IndexedPageReader`] with [`PageMetaData`]. - pub fn new_with_page_meta( - reader: MemReader, - column: PageMetaData, - pages: Vec, - buffer: Vec, - data_buffer: Vec, - ) -> Self { - let pages = pages.into_iter().collect(); - Self { - reader, - column_start: column.column_start, - compression: column.compression, - descriptor: column.descriptor, - buffer, - data_buffer, - pages, - state: State::MaybeDict, - } - } - - /// consumes self into the reader and the two internal buffers - pub fn into_inner(self) -> (MemReader, Vec, Vec) { - (self.reader, self.buffer, self.data_buffer) - } - - fn read_page( - &mut self, - start: u64, - length: usize, - selected_rows: Vec, - ) -> Result { - let (page_header, data) = read_page(&mut self.reader, start, length)?; - - finish_page( - page_header, - data, - self.compression, - &self.descriptor, - Some(selected_rows), - ) - } - - fn read_dict(&mut self) -> Option> { - // a dictionary page exists iff the first data page is not at the start of - // the column - let (start, length) = match self.pages.front() { - Some(page) => { - let length = (page.start - self.column_start) as usize; - if length > 0 { - (self.column_start, length) - } else { - return None; - } - }, - None => return None, - }; - - let maybe_page = read_dict_page( - &mut self.reader, - start, - length, - self.compression, - &self.descriptor, - ); - Some(maybe_page.map(CompressedPage::Dict)) - } -} - -impl Iterator for IndexedPageReader { - type Item = Result; - - fn next(&mut self) -> Option { - match self.state { - State::MaybeDict => { - self.state = State::Data; - if let Some(dict) = self.read_dict() { - Some(dict) - } else { - self.next() - } - }, - State::Data => { - if let Some(page) = self.pages.pop_front() { - if page.selected_rows.is_empty() { - self.next() - } else { - Some(self.read_page(page.start, page.length, page.selected_rows)) - } - } else { - None - } - }, - } - } -} diff --git a/crates/polars-parquet/src/parquet/read/page/mod.rs b/crates/polars-parquet/src/parquet/read/page/mod.rs index 98d76493ba50..14801839a693 100644 --- a/crates/polars-parquet/src/parquet/read/page/mod.rs +++ b/crates/polars-parquet/src/parquet/read/page/mod.rs @@ -2,7 +2,7 @@ mod reader; #[cfg(feature = "async")] mod stream; -pub use reader::{PageFilter, PageMetaData, PageReader}; +pub use reader::{PageMetaData, PageReader}; use crate::parquet::error::ParquetError; use crate::parquet::page::CompressedPage; diff --git a/crates/polars-parquet/src/parquet/read/page/reader.rs b/crates/polars-parquet/src/parquet/read/page/reader.rs index cf01a25d7e07..cd23af0499d7 100644 --- a/crates/polars-parquet/src/parquet/read/page/reader.rs +++ b/crates/polars-parquet/src/parquet/read/page/reader.rs @@ -1,5 +1,5 @@ use std::io::Seek; -use std::sync::{Arc, OnceLock}; +use std::sync::OnceLock; use parquet_format_safe::thrift::protocol::TCompactInputProtocol; use polars_utils::mmap::{MemReader, MemSlice}; @@ -7,15 +7,14 @@ use polars_utils::mmap::{MemReader, MemSlice}; use super::PageIterator; use crate::parquet::compression::Compression; use crate::parquet::error::{ParquetError, ParquetResult}; -use crate::parquet::indexes::Interval; -use crate::parquet::metadata::{ColumnChunkMetaData, Descriptor}; +use crate::parquet::metadata::{ColumnChunkMetadata, Descriptor}; use crate::parquet::page::{ 
CompressedDataPage, CompressedDictPage, CompressedPage, DataPageHeader, PageType, ParquetPageHeader, }; use crate::parquet::CowBuffer; -/// This meta is a small part of [`ColumnChunkMetaData`]. +/// This meta is a small part of [`ColumnChunkMetadata`]. #[derive(Debug, Clone, PartialEq, Eq)] pub struct PageMetaData { /// The start offset of this column chunk in file. @@ -45,10 +44,10 @@ impl PageMetaData { } } -impl From<&ColumnChunkMetaData> for PageMetaData { - fn from(column: &ColumnChunkMetaData) -> Self { +impl From<&ColumnChunkMetadata> for PageMetaData { + fn from(column: &ColumnChunkMetadata) -> Self { Self { - column_start: column.byte_range().0, + column_start: column.byte_range().start, num_values: column.num_values(), compression: column.compression(), descriptor: column.descriptor().descriptor.clone(), @@ -56,11 +55,9 @@ impl From<&ColumnChunkMetaData> for PageMetaData { } } -/// Type declaration for a page filter -pub type PageFilter = Arc bool + Send + Sync>; - /// A fallible [`Iterator`] of [`CompressedDataPage`]. This iterator reads pages back /// to back until all pages have been consumed. +/// /// The pages from this iterator always have [`None`] [`crate::parquet::page::CompressedDataPage::selected_rows()`] since /// filter pushdown is not supported without a /// pre-computed [page index](https://github.com/apache/parquet-format/blob/master/PageIndex.md). @@ -76,8 +73,6 @@ pub struct PageReader { // The number of total values in this column chunk. total_num_values: i64, - pages_filter: PageFilter, - descriptor: Descriptor, // The currently allocated buffer. @@ -94,12 +89,11 @@ impl PageReader { /// The parameter `max_header_size` pub fn new( reader: MemReader, - column: &ColumnChunkMetaData, - pages_filter: PageFilter, + column: &ColumnChunkMetadata, scratch: Vec, max_page_size: usize, ) -> Self { - Self::new_with_page_meta(reader, column.into(), pages_filter, scratch, max_page_size) + Self::new_with_page_meta(reader, column.into(), scratch, max_page_size) } /// Create a a new [`PageReader`] with [`PageMetaData`]. @@ -108,7 +102,6 @@ impl PageReader { pub fn new_with_page_meta( reader: MemReader, reader_meta: PageMetaData, - pages_filter: PageFilter, scratch: Vec, max_page_size: usize, ) -> Self { @@ -118,7 +111,6 @@ impl PageReader { compression: reader_meta.compression, seen_num_values: 0, descriptor: reader_meta.descriptor, - pages_filter, scratch, max_page_size, } @@ -135,6 +127,12 @@ impl PageReader { } pub fn read_dict(&mut self) -> ParquetResult> { + // If there are no pages, we cannot check if the first page is a dictionary page. Just + // return the fact there is no dictionary page. 
+ if self.reader.remaining_len() == 0 { + return Ok(None); + } + // a dictionary page exists iff the first data page is not at the start of // the column let seek_offset = self.reader.position(); @@ -161,14 +159,7 @@ impl PageReader { )); } - finish_page( - page_header, - buffer, - self.compression, - &self.descriptor, - None, - ) - .map(|p| { + finish_page(page_header, buffer, self.compression, &self.descriptor).map(|p| { if let CompressedPage::Dict(d) = p { Some(d) } else { @@ -190,16 +181,7 @@ impl Iterator for PageReader { fn next(&mut self) -> Option { let mut buffer = std::mem::take(&mut self.scratch); let maybe_maybe_page = next_page(self).transpose(); - if let Some(ref maybe_page) = maybe_maybe_page { - if let Ok(CompressedPage::Data(page)) = maybe_page { - // check if we should filter it (only valid for data pages) - let to_consume = (self.pages_filter)(&self.descriptor, page.header()); - if !to_consume { - self.scratch = std::mem::take(&mut buffer); - return self.next(); - } - } - } else { + if maybe_maybe_page.is_none() { // no page => we take back the buffer self.scratch = std::mem::take(&mut buffer); } @@ -245,14 +227,7 @@ pub(super) fn build_page(reader: &mut PageReader) -> ParquetResult>, ) -> ParquetResult { let type_ = page_header.type_.try_into()?; let uncompressed_page_size = page_header.uncompressed_page_size.try_into()?; @@ -313,7 +287,6 @@ pub(super) fn finish_page( compression, uncompressed_page_size, descriptor.clone(), - selected_rows, ))) }, PageType::DataPageV2 => { @@ -336,7 +309,6 @@ pub(super) fn finish_page( compression, uncompressed_page_size, descriptor.clone(), - selected_rows, ))) }, } diff --git a/crates/polars-parquet/src/parquet/read/page/stream.rs b/crates/polars-parquet/src/parquet/read/page/stream.rs index 7b89dc3937cd..fbd36b3ccfe1 100644 --- a/crates/polars-parquet/src/parquet/read/page/stream.rs +++ b/crates/polars-parquet/src/parquet/read/page/stream.rs @@ -1,43 +1,32 @@ use std::io::SeekFrom; use async_stream::try_stream; -use futures::io::{copy, sink}; use futures::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, Stream}; use parquet_format_safe::thrift::protocol::TCompactInputStreamProtocol; use polars_utils::mmap::MemSlice; use super::reader::{finish_page, PageMetaData}; -use super::PageFilter; use crate::parquet::compression::Compression; use crate::parquet::error::{ParquetError, ParquetResult}; -use crate::parquet::metadata::{ColumnChunkMetaData, Descriptor}; +use crate::parquet::metadata::{ColumnChunkMetadata, Descriptor}; use crate::parquet::page::{CompressedPage, DataPageHeader, ParquetPageHeader}; use crate::parquet::parquet_bridge::{Encoding, PageType}; /// Returns a stream of compressed data pages pub async fn get_page_stream<'a, RR: AsyncRead + Unpin + Send + AsyncSeek>( - column_metadata: &'a ColumnChunkMetaData, + column_metadata: &'a ColumnChunkMetadata, reader: &'a mut RR, scratch: Vec, - pages_filter: PageFilter, max_page_size: usize, ) -> ParquetResult> + 'a> { - get_page_stream_with_page_meta( - column_metadata.into(), - reader, - scratch, - pages_filter, - max_page_size, - ) - .await + get_page_stream_with_page_meta(column_metadata.into(), reader, scratch, max_page_size).await } /// Returns a stream of compressed data pages from a reader that begins at the start of the column pub async fn get_page_stream_from_column_start<'a, R: AsyncRead + Unpin + Send>( - column_metadata: &'a ColumnChunkMetaData, + column_metadata: &'a ColumnChunkMetadata, reader: &'a mut R, scratch: Vec, - pages_filter: PageFilter, max_header_size: 
usize, ) -> ParquetResult> + 'a> { let page_metadata: PageMetaData = column_metadata.into(); @@ -47,7 +36,6 @@ pub async fn get_page_stream_from_column_start<'a, R: AsyncRead + Unpin + Send>( page_metadata.compression, page_metadata.descriptor, scratch, - pages_filter, max_header_size, )) } @@ -57,7 +45,6 @@ pub async fn get_page_stream_with_page_meta, - pages_filter: PageFilter, max_page_size: usize, ) -> ParquetResult> + '_> { let column_start = page_metadata.column_start; @@ -68,7 +55,6 @@ pub async fn get_page_stream_with_page_meta( compression: Compression, descriptor: Descriptor, mut scratch: Vec, - pages_filter: PageFilter, max_page_size: usize, ) -> impl Stream> + '_ { let mut seen_values = 0i64; @@ -93,14 +78,6 @@ fn _get_page_stream( let read_size: usize = page_header.compressed_page_size.try_into()?; - if let Some(data_header) = data_header { - if !pages_filter(&descriptor, &data_header) { - // page to be skipped, we sill need to seek - copy(reader.take(read_size as u64), &mut sink()).await?; - continue - } - } - if read_size > max_page_size { Err(ParquetError::WouldOverAllocate)? } @@ -123,7 +100,6 @@ fn _get_page_stream( MemSlice::from_vec(std::mem::take(&mut scratch)), compression, &descriptor, - None, )?; } } diff --git a/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs b/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs index 36da3d5edcd1..d4f2c692e95d 100644 --- a/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs +++ b/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs @@ -43,6 +43,7 @@ //! ``` use parquet_format_safe::Type; +use polars_utils::pl_str::PlSmallStr; use types::PrimitiveLogicalType; use super::super::types::{ParquetType, TimeUnit}; @@ -158,9 +159,11 @@ fn type_from_str(s: &str) -> ParquetResult { } } -/// Parses message type as string into a Parquet [`ParquetType`](crate::parquet::schema::types::ParquetType) -/// which, for example, could be used to extract individual columns. Returns Parquet -/// general error when parsing or validation fails. +/// Parses message type as string into a Parquet [`ParquetType`](crate::parquet::schema::types::ParquetType). +/// +/// This could, for example, be used to extract individual columns. +/// +/// Returns Parquet general error when parsing or validation fails. 
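A small usage sketch for the parser documented here and defined just below (the schema text is an assumed example in the standard message-type syntax; field names are stored as `PlSmallStr` after this change):

    let schema: ParquetType = from_message(
        "message root { required int32 id; optional binary name (UTF8); }",
    )?;
    // `schema` is the group node named "root" with two primitive children.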
pub fn from_message(message_type: &str) -> ParquetResult { let mut parser = Parser { tokenizer: &mut Tokenizer::from_str(message_type), @@ -311,7 +314,7 @@ impl<'a> Parser<'a> { .next() .ok_or_else(|| ParquetError::oos("Expected name, found None"))?; let fields = self.parse_child_types()?; - Ok(ParquetType::new_root(name.to_string(), fields)) + Ok(ParquetType::new_root(PlSmallStr::from_str(name), fields)) }, _ => Err(ParquetError::oos( "Message type does not start with 'message'", @@ -387,7 +390,7 @@ impl<'a> Parser<'a> { let fields = self.parse_child_types()?; Ok(ParquetType::from_converted( - name.to_string(), + PlSmallStr::from_str(name), fields, repetition, converted_type, @@ -471,7 +474,7 @@ impl<'a> Parser<'a> { assert_token(self.tokenizer.next(), ";")?; ParquetType::try_from_primitive( - name.to_string(), + PlSmallStr::from_str(name), (physical_type, length).try_into()?, repetition, converted_type, @@ -881,7 +884,7 @@ mod tests { let fields = vec![ ParquetType::try_from_primitive( - "f1".to_string(), + PlSmallStr::from_static("f1"), PhysicalType::FixedLenByteArray(5), Repetition::Optional, None, @@ -889,7 +892,7 @@ mod tests { None, )?, ParquetType::try_from_primitive( - "f2".to_string(), + PlSmallStr::from_static("f2"), PhysicalType::FixedLenByteArray(16), Repetition::Optional, None, @@ -898,7 +901,7 @@ mod tests { )?, ]; - let expected = ParquetType::new_root("root".to_string(), fields); + let expected = ParquetType::new_root(PlSmallStr::from_static("root"), fields); assert_eq!(message, expected); Ok(()) @@ -930,7 +933,7 @@ mod tests { .unwrap(); let a2 = ParquetType::try_from_primitive( - "a2".to_string(), + "a2".into(), PhysicalType::ByteArray, Repetition::Repeated, Some(PrimitiveConvertedType::Utf8), @@ -938,38 +941,38 @@ mod tests { None, )?; let a1 = ParquetType::from_converted( - "a1".to_string(), + "a1".into(), vec![a2], Repetition::Optional, Some(GroupConvertedType::List), None, ); let b2 = ParquetType::from_converted( - "b2".to_string(), + "b2".into(), vec![ - ParquetType::from_physical("b3".to_string(), PhysicalType::Int32), - ParquetType::from_physical("b4".to_string(), PhysicalType::Double), + ParquetType::from_physical("b3".into(), PhysicalType::Int32), + ParquetType::from_physical("b4".into(), PhysicalType::Double), ], Repetition::Repeated, None, None, ); let b1 = ParquetType::from_converted( - "b1".to_string(), + "b1".into(), vec![b2], Repetition::Optional, Some(GroupConvertedType::List), None, ); let a0 = ParquetType::from_converted( - "a0".to_string(), + "a0".into(), vec![a1, b1], Repetition::Required, None, None, ); - let expected = ParquetType::new_root("root".to_string(), vec![a0]); + let expected = ParquetType::new_root("root".into(), vec![a0]); assert_eq!(message, expected); Ok(()) @@ -995,7 +998,7 @@ mod tests { .unwrap(); let f1 = ParquetType::try_from_primitive( - "_1".to_string(), + "_1".into(), PhysicalType::Int32, Repetition::Required, Some(PrimitiveConvertedType::Int8), @@ -1003,7 +1006,7 @@ mod tests { None, )?; let f2 = ParquetType::try_from_primitive( - "_2".to_string(), + "_2".into(), PhysicalType::Int32, Repetition::Required, Some(PrimitiveConvertedType::Int16), @@ -1011,7 +1014,7 @@ mod tests { None, )?; let f3 = ParquetType::try_from_primitive( - "_3".to_string(), + "_3".into(), PhysicalType::Float, Repetition::Required, None, @@ -1019,7 +1022,7 @@ mod tests { None, )?; let f4 = ParquetType::try_from_primitive( - "_4".to_string(), + "_4".into(), PhysicalType::Double, Repetition::Required, None, @@ -1027,7 +1030,7 @@ mod tests { None, )?; let f5 
= ParquetType::try_from_primitive( - "_5".to_string(), + "_5".into(), PhysicalType::Int32, Repetition::Optional, None, @@ -1035,7 +1038,7 @@ mod tests { None, )?; let f6 = ParquetType::try_from_primitive( - "_6".to_string(), + "_6".into(), PhysicalType::ByteArray, Repetition::Optional, Some(PrimitiveConvertedType::Utf8), @@ -1045,7 +1048,7 @@ mod tests { let fields = vec![f1, f2, f3, f4, f5, f6]; - let expected = ParquetType::new_root("root".to_string(), fields); + let expected = ParquetType::new_root("root".into(), fields); assert_eq!(message, expected); Ok(()) } @@ -1073,7 +1076,7 @@ mod tests { .parse_message_type()?; let f1 = ParquetType::try_from_primitive( - "_1".to_string(), + "_1".into(), PhysicalType::Int32, Repetition::Required, None, @@ -1081,7 +1084,7 @@ mod tests { None, )?; let f2 = ParquetType::try_from_primitive( - "_2".to_string(), + "_2".into(), PhysicalType::Int32, Repetition::Required, None, @@ -1089,7 +1092,7 @@ mod tests { None, )?; let f3 = ParquetType::try_from_primitive( - "_3".to_string(), + "_3".into(), PhysicalType::Float, Repetition::Required, None, @@ -1097,7 +1100,7 @@ mod tests { None, )?; let f4 = ParquetType::try_from_primitive( - "_4".to_string(), + "_4".into(), PhysicalType::Double, Repetition::Required, None, @@ -1105,7 +1108,7 @@ mod tests { None, )?; let f5 = ParquetType::try_from_primitive( - "_5".to_string(), + "_5".into(), PhysicalType::Int32, Repetition::Optional, None, @@ -1113,7 +1116,7 @@ mod tests { None, )?; let f6 = ParquetType::try_from_primitive( - "_6".to_string(), + "_6".into(), PhysicalType::Int32, Repetition::Optional, None, @@ -1124,7 +1127,7 @@ mod tests { None, )?; let f7 = ParquetType::try_from_primitive( - "_7".to_string(), + "_7".into(), PhysicalType::Int64, Repetition::Optional, None, @@ -1135,7 +1138,7 @@ mod tests { None, )?; let f8 = ParquetType::try_from_primitive( - "_8".to_string(), + "_8".into(), PhysicalType::Int64, Repetition::Optional, None, @@ -1146,7 +1149,7 @@ mod tests { None, )?; let f9 = ParquetType::try_from_primitive( - "_9".to_string(), + "_9".into(), PhysicalType::Int64, Repetition::Optional, None, @@ -1158,7 +1161,7 @@ mod tests { )?; let f10 = ParquetType::try_from_primitive( - "_10".to_string(), + "_10".into(), PhysicalType::ByteArray, Repetition::Optional, None, @@ -1168,7 +1171,7 @@ mod tests { let fields = vec![f1, f2, f3, f4, f5, f6, f7, f8, f9, f10]; - let expected = ParquetType::new_root("root".to_string(), fields); + let expected = ParquetType::new_root("root".into(), fields); assert_eq!(message, expected); Ok(()) } diff --git a/crates/polars-parquet/src/parquet/schema/io_thrift/from_thrift.rs b/crates/polars-parquet/src/parquet/schema/io_thrift/from_thrift.rs index b4c0733df769..b0bbe20999bc 100644 --- a/crates/polars-parquet/src/parquet/schema/io_thrift/from_thrift.rs +++ b/crates/polars-parquet/src/parquet/schema/io_thrift/from_thrift.rs @@ -1,4 +1,5 @@ use parquet_format_safe::SchemaElement; +use polars_utils::pl_str::PlSmallStr; use super::super::types::ParquetType; use crate::parquet::error::{ParquetError, ParquetResult}; @@ -40,11 +41,18 @@ fn from_thrift_helper( let element = elements.get(index).ok_or_else(|| { ParquetError::oos(format!("index {} on SchemaElement is not valid", index)) })?; - let name = element.name.clone(); + let name = PlSmallStr::from_str(element.name.as_str()); let converted_type = element.converted_type; let id = element.field_id; match element.num_children { + // empty root + None | Some(0) if is_root_node => { + let fields = vec![]; + let tp = 
ParquetType::new_root(name, fields); + Ok((index + 1, tp)) + }, + // From parquet-format: // The children count is used to construct the nested relationship. // This field is not set when the element is a primitive type diff --git a/crates/polars-parquet/src/parquet/schema/io_thrift/to_thrift.rs b/crates/polars-parquet/src/parquet/schema/io_thrift/to_thrift.rs index 27c9d886b2ef..3aef1fe792fa 100644 --- a/crates/polars-parquet/src/parquet/schema/io_thrift/to_thrift.rs +++ b/crates/polars-parquet/src/parquet/schema/io_thrift/to_thrift.rs @@ -32,7 +32,7 @@ fn to_thrift_helper(schema: &ParquetType, elements: &mut Vec, is_ type_: Some(type_), type_length, repetition_type: Some(field_info.repetition.into()), - name: field_info.name.clone(), + name: field_info.name.to_string(), num_children: None, converted_type, precision: maybe_decimal.map(|x| x.0), @@ -62,7 +62,7 @@ fn to_thrift_helper(schema: &ParquetType, elements: &mut Vec, is_ type_: None, type_length: None, repetition_type: repetition_type.map(|x| x.into()), - name: field_info.name.clone(), + name: field_info.name.to_string(), num_children: Some(fields.len() as i32), converted_type, scale: None, diff --git a/crates/polars-parquet/src/parquet/schema/types/basic_type.rs b/crates/polars-parquet/src/parquet/schema/types/basic_type.rs index b3697fcaa1c3..e882f83516f5 100644 --- a/crates/polars-parquet/src/parquet/schema/types/basic_type.rs +++ b/crates/polars-parquet/src/parquet/schema/types/basic_type.rs @@ -1,3 +1,4 @@ +use polars_utils::pl_str::PlSmallStr; #[cfg(feature = "serde_types")] use serde::{Deserialize, Serialize}; @@ -8,7 +9,7 @@ use super::super::Repetition; #[cfg_attr(feature = "serde_types", derive(Deserialize, Serialize))] pub struct FieldInfo { /// The field name - pub name: String, + pub name: PlSmallStr, /// The repetition pub repetition: Repetition, /// the optional id, to select fields by id diff --git a/crates/polars-parquet/src/parquet/schema/types/parquet_type.rs b/crates/polars-parquet/src/parquet/schema/types/parquet_type.rs index c5c5642eb1c6..ad703cc884a3 100644 --- a/crates/polars-parquet/src/parquet/schema/types/parquet_type.rs +++ b/crates/polars-parquet/src/parquet/schema/types/parquet_type.rs @@ -1,5 +1,6 @@ // see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md use polars_utils::aliases::*; +use polars_utils::pl_str::PlSmallStr; #[cfg(feature = "serde_types")] use serde::{Deserialize, Serialize}; @@ -26,7 +27,7 @@ pub struct PrimitiveType { impl PrimitiveType { /// Helper method to create an optional field with no logical or converted types. - pub fn from_physical(name: String, physical_type: PhysicalType) -> Self { + pub fn from_physical(name: PlSmallStr, physical_type: PhysicalType) -> Self { let field_info = FieldInfo { name, repetition: Repetition::Optional, @@ -114,7 +115,7 @@ impl ParquetType { /// Constructors impl ParquetType { - pub(crate) fn new_root(name: String, fields: Vec) -> Self { + pub(crate) fn new_root(name: PlSmallStr, fields: Vec) -> Self { let field_info = FieldInfo { name, repetition: Repetition::Optional, @@ -129,7 +130,7 @@ impl ParquetType { } pub fn from_converted( - name: String, + name: PlSmallStr, fields: Vec, repetition: Repetition, converted_type: Option, @@ -152,7 +153,7 @@ impl ParquetType { /// # Error /// Errors iff the combination of physical, logical and converted type is not valid. 
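Schema construction now takes `PlSmallStr` names instead of `String`; plain `&str` still works through `.into()`, as the rewritten tests show. A minimal sketch (note that `new_root` is `pub(crate)`, so this mirrors in-crate test code):

    use polars_utils::pl_str::PlSmallStr;

    // Both spellings end up as PlSmallStr-backed field names.
    let a = ParquetType::from_physical("a".into(), PhysicalType::Int32);
    let b = ParquetType::from_physical(PlSmallStr::from_static("b"), PhysicalType::Double);
    let root = ParquetType::new_root("root".into(), vec![a, b]);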
pub fn try_from_primitive( - name: String, + name: PlSmallStr, physical_type: PhysicalType, repetition: Repetition, converted_type: Option, @@ -178,12 +179,12 @@ impl ParquetType { /// Helper method to create a [`ParquetType::PrimitiveType`] optional field /// with no logical or converted types. - pub fn from_physical(name: String, physical_type: PhysicalType) -> Self { + pub fn from_physical(name: PlSmallStr, physical_type: PhysicalType) -> Self { ParquetType::PrimitiveType(PrimitiveType::from_physical(name, physical_type)) } pub fn from_group( - name: String, + name: PlSmallStr, repetition: Repetition, converted_type: Option, logical_type: Option, diff --git a/crates/polars-parquet/src/parquet/write/column_chunk.rs b/crates/polars-parquet/src/parquet/write/column_chunk.rs index 3a5a9a504d9c..6ae51a191dc5 100644 --- a/crates/polars-parquet/src/parquet/write/column_chunk.rs +++ b/crates/polars-parquet/src/parquet/write/column_chunk.rs @@ -179,7 +179,11 @@ fn build_column_chunk( let metadata = ColumnMetaData { type_, encodings, - path_in_schema: descriptor.path_in_schema.clone(), + path_in_schema: descriptor + .path_in_schema + .iter() + .map(|x| x.to_string()) + .collect::>(), codec: compression.into(), num_values, total_uncompressed_size, diff --git a/crates/polars-parquet/src/parquet/write/compression.rs b/crates/polars-parquet/src/parquet/write/compression.rs index 1c7d4d36a901..04d01a6e34bc 100644 --- a/crates/polars-parquet/src/parquet/write/compression.rs +++ b/crates/polars-parquet/src/parquet/write/compression.rs @@ -16,9 +16,10 @@ fn compress_data( mut buffer, header, descriptor, - selected_rows, + num_rows, } = page; let uncompressed_page_size = buffer.len(); + let num_rows = num_rows.expect("We should have num_rows when we are writing"); if compression != CompressionOptions::Uncompressed { match &header { DataPageHeader::V1(_) => { @@ -40,13 +41,13 @@ fn compress_data( std::mem::swap(buffer.to_mut(), &mut compressed_buffer); } - Ok(CompressedDataPage::new_read( + Ok(CompressedDataPage::new( header, CowBuffer::Owned(compressed_buffer), compression.into(), uncompressed_page_size, descriptor, - selected_rows, + num_rows, )) } diff --git a/crates/polars-parquet/src/parquet/write/indexes/serialize.rs b/crates/polars-parquet/src/parquet/write/indexes/serialize.rs index 8b3cebec1686..14594bc2b8c4 100644 --- a/crates/polars-parquet/src/parquet/write/indexes/serialize.rs +++ b/crates/polars-parquet/src/parquet/write/indexes/serialize.rs @@ -62,11 +62,7 @@ pub fn serialize_offset_index(pages: &[PageWriteSpec]) -> ParquetResult, + /// The number of actual rows. For non-nested values, this is equal to the number of values. 
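The `num_rows` recorded on the page write spec just below is distinct from `num_values` as soon as nesting is involved; an invented illustration of the difference:

    // For a page of a List<Int32> column holding the two rows [[1, 2, 3], [4]]:
    //   num_rows   = 2   // top-level rows in the page
    //   num_values = 4   // leaf values encoded in the page
    // For a flat, non-nested column the two counts coincide, which is what the
    // doc comment means by "equal to the number of values".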
+ pub num_rows: usize, pub header_size: u64, pub offset: u64, pub bytes_written: u64, @@ -55,7 +56,9 @@ pub fn write_page( compressed_page: &CompressedPage, ) -> ParquetResult { let num_values = compressed_page.num_values(); - let selected_rows = compressed_page.selected_rows(); + let num_rows = compressed_page + .num_rows() + .expect("We should have num_rows when we are writing"); let header = match &compressed_page { CompressedPage::Data(compressed_page) => assemble_data_page_header(compressed_page), @@ -88,8 +91,8 @@ pub fn write_page( bytes_written, compression: compressed_page.compression(), statistics, - num_rows: selected_rows.map(|x| x.last().unwrap().length), num_values, + num_rows, }) } @@ -101,7 +104,9 @@ pub async fn write_page_async( compressed_page: &CompressedPage, ) -> ParquetResult { let num_values = compressed_page.num_values(); - let selected_rows = compressed_page.selected_rows(); + let num_rows = compressed_page + .num_rows() + .expect("We should have the num_rows when we are writing"); let header = match &compressed_page { CompressedPage::Data(compressed_page) => assemble_data_page_header(compressed_page), @@ -134,7 +139,7 @@ pub async fn write_page_async( bytes_written, compression: compressed_page.compression(), statistics, - num_rows: selected_rows.map(|x| x.last().unwrap().length), + num_rows, num_values, }) } diff --git a/crates/polars-parquet/src/parquet/write/row_group.rs b/crates/polars-parquet/src/parquet/write/row_group.rs index e5c535055ea6..43404dc32a89 100644 --- a/crates/polars-parquet/src/parquet/write/row_group.rs +++ b/crates/polars-parquet/src/parquet/write/row_group.rs @@ -10,7 +10,7 @@ use super::column_chunk::write_column_chunk_async; use super::page::{is_data_page, PageWriteSpec}; use super::{DynIter, DynStreamingIterator}; use crate::parquet::error::{ParquetError, ParquetResult}; -use crate::parquet::metadata::{ColumnChunkMetaData, ColumnDescriptor}; +use crate::parquet::metadata::{ColumnChunkMetadata, ColumnDescriptor}; use crate::parquet::page::CompressedPage; pub struct ColumnOffsetsMetadata { @@ -34,7 +34,7 @@ impl ColumnOffsetsMetadata { } pub fn from_column_chunk_metadata( - column_chunk_metadata: &ColumnChunkMetaData, + column_chunk_metadata: &ColumnChunkMetadata, ) -> ColumnOffsetsMetadata { ColumnOffsetsMetadata { dictionary_page_offset: column_chunk_metadata.dictionary_page_offset(), @@ -58,9 +58,7 @@ fn compute_num_rows(columns: &[(ColumnChunk, Vec)]) -> ParquetRes .iter() .filter(|x| is_data_page(x)) .try_for_each(|spec| { - num_rows += spec.num_rows.ok_or_else(|| { - ParquetError::oos("All data pages must declare the number of rows on it") - })? 
as i64; + num_rows += spec.num_rows as i64; ParquetResult::Ok(()) })?; ParquetResult::Ok(num_rows) diff --git a/crates/polars-parquet/src/parquet/write/statistics.rs b/crates/polars-parquet/src/parquet/write/statistics.rs index d37256d3ca1e..064a16eb931b 100644 --- a/crates/polars-parquet/src/parquet/write/statistics.rs +++ b/crates/polars-parquet/src/parquet/write/statistics.rs @@ -41,7 +41,7 @@ pub fn reduce(stats: &[&Option]) -> ParquetResult .all(|x| x.physical_type() == stats[0].physical_type()); if !same_type { return Err(ParquetError::oos( - "The statistics do not have the same data_type", + "The statistics do not have the same dtype", )); }; @@ -164,20 +164,14 @@ mod tests { fn binary() -> ParquetResult<()> { let iter = vec![ BinaryStatistics { - primitive_type: PrimitiveType::from_physical( - "bla".to_string(), - PhysicalType::ByteArray, - ), + primitive_type: PrimitiveType::from_physical("bla".into(), PhysicalType::ByteArray), null_count: Some(0), distinct_count: None, min_value: Some(vec![1, 2]), max_value: Some(vec![3, 4]), }, BinaryStatistics { - primitive_type: PrimitiveType::from_physical( - "bla".to_string(), - PhysicalType::ByteArray, - ), + primitive_type: PrimitiveType::from_physical("bla".into(), PhysicalType::ByteArray), null_count: Some(0), distinct_count: None, min_value: Some(vec![4, 5]), @@ -189,10 +183,7 @@ mod tests { assert_eq!( a, BinaryStatistics { - primitive_type: PrimitiveType::from_physical( - "bla".to_string(), - PhysicalType::ByteArray, - ), + primitive_type: PrimitiveType::from_physical("bla".into(), PhysicalType::ByteArray,), null_count: Some(0), distinct_count: None, min_value: Some(vec![1, 2]), @@ -208,7 +199,7 @@ mod tests { let iter = vec![ FixedLenStatistics { primitive_type: PrimitiveType::from_physical( - "bla".to_string(), + "bla".into(), PhysicalType::FixedLenByteArray(2), ), null_count: Some(0), @@ -218,7 +209,7 @@ mod tests { }, FixedLenStatistics { primitive_type: PrimitiveType::from_physical( - "bla".to_string(), + "bla".into(), PhysicalType::FixedLenByteArray(2), ), null_count: Some(0), @@ -233,7 +224,7 @@ mod tests { a, FixedLenStatistics { primitive_type: PrimitiveType::from_physical( - "bla".to_string(), + "bla".into(), PhysicalType::FixedLenByteArray(2), ), null_count: Some(0), @@ -284,7 +275,7 @@ mod tests { distinct_count: None, min_value: Some(30), max_value: Some(70), - primitive_type: PrimitiveType::from_physical("bla".to_string(), PhysicalType::Int32), + primitive_type: PrimitiveType::from_physical("bla".into(), PhysicalType::Int32), }]; let a = reduce_primitive(iter.iter()); @@ -295,10 +286,7 @@ mod tests { distinct_count: None, min_value: Some(30), max_value: Some(70), - primitive_type: PrimitiveType::from_physical( - "bla".to_string(), - PhysicalType::Int32, - ), + primitive_type: PrimitiveType::from_physical("bla".into(), PhysicalType::Int32,), }, ); diff --git a/crates/polars-pipe/Cargo.toml b/crates/polars-pipe/Cargo.toml index b11a43bbaae8..cec5c8484285 100644 --- a/crates/polars-pipe/Cargo.toml +++ b/crates/polars-pipe/Cargo.toml @@ -28,7 +28,6 @@ enum_dispatch = { version = "0.3" } hashbrown = { workspace = true } num-traits = { workspace = true } rayon = { workspace = true } -smartstring = { workspace = true } [build-dependencies] version_check = { workspace = true } diff --git a/crates/polars-pipe/src/executors/operators/function.rs b/crates/polars-pipe/src/executors/operators/function.rs index 6501536f6ee1..de17b9d0c15a 100644 --- a/crates/polars-pipe/src/executors/operators/function.rs +++ 
b/crates/polars-pipe/src/executors/operators/function.rs @@ -13,11 +13,11 @@ pub struct FunctionOperator { n_threads: usize, chunk_size: usize, offsets: VecDeque<(usize, usize)>, - function: FunctionNode, + function: FunctionIR, } impl FunctionOperator { - pub(crate) fn new(function: FunctionNode) -> Self { + pub(crate) fn new(function: FunctionIR) -> Self { FunctionOperator { n_threads: POOL.current_num_threads(), function, diff --git a/crates/polars-pipe/src/executors/operators/projection.rs b/crates/polars-pipe/src/executors/operators/projection.rs index f609a592f8da..67141d0c44a7 100644 --- a/crates/polars-pipe/src/executors/operators/projection.rs +++ b/crates/polars-pipe/src/executors/operators/projection.rs @@ -4,19 +4,19 @@ use polars_core::error::PolarsResult; use polars_core::frame::DataFrame; use polars_core::schema::SchemaRef; use polars_plan::prelude::ProjectionOptions; -use smartstring::alias::String as SmartString; +use polars_utils::pl_str::PlSmallStr; use crate::expressions::PhysicalPipedExpr; use crate::operators::{DataChunk, Operator, OperatorResult, PExecutionContext}; #[derive(Clone)] pub(crate) struct SimpleProjectionOperator { - columns: Arc<[SmartString]>, + columns: Arc<[PlSmallStr]>, input_schema: SchemaRef, } impl SimpleProjectionOperator { - pub(crate) fn new(columns: Arc<[SmartString]>, input_schema: SchemaRef) -> Self { + pub(crate) fn new(columns: Arc<[PlSmallStr]>, input_schema: SchemaRef) -> Self { Self { columns, input_schema, @@ -30,11 +30,12 @@ impl Operator for SimpleProjectionOperator { _context: &PExecutionContext, chunk: &DataChunk, ) -> PolarsResult { - let chunk = chunk.with_data( - chunk - .data - .select_with_schema_unchecked(self.columns.as_ref(), &self.input_schema)?, - ); + let check_duplicates = false; + let chunk = chunk.with_data(chunk.data._select_with_schema_impl( + self.columns.as_ref(), + &self.input_schema, + check_duplicates, + )?); Ok(OperatorResult::Finished(chunk)) } fn split(&self, _thread_no: usize) -> Box { diff --git a/crates/polars-pipe/src/executors/operators/reproject.rs b/crates/polars-pipe/src/executors/operators/reproject.rs index ca2bd5cb1e78..a4f6010bef79 100644 --- a/crates/polars-pipe/src/executors/operators/reproject.rs +++ b/crates/polars-pipe/src/executors/operators/reproject.rs @@ -16,7 +16,7 @@ pub(crate) fn reproject_chunk( let out = chunk .data - .select_with_schema_unchecked(schema.iter_names(), &chunk_schema)?; + .select_with_schema_unchecked(schema.iter_names_cloned(), &chunk_schema)?; *positions = out .get_columns() diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs index 1603de3729fa..4e81e8531bac 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs @@ -12,6 +12,7 @@ use polars_plan::plans::expr_ir::ExprIR; use polars_plan::plans::{ArenaExprIter, Context}; use polars_plan::prelude::{AExpr, IRAggExpr}; use polars_utils::arena::{Arena, Node}; +use polars_utils::pl_str::PlSmallStr; use polars_utils::IdxSize; use crate::executors::sinks::group_by::aggregates::count::CountAgg; @@ -31,7 +32,7 @@ impl PhysicalIoExpr for Len { unimplemented!() } - fn live_variables(&self) -> Option>> { + fn live_variables(&self) -> Option> { Some(vec![]) } } @@ -39,7 +40,7 @@ impl PhysicalPipedExpr for Len { fn evaluate(&self, chunk: &DataChunk, _lazy_state: &ExecutionState) -> PolarsResult { // the length must match the 
chunks as the operators expect that // so we fill a null series. - Ok(Series::new_null("", chunk.data.height())) + Ok(Series::new_null(PlSmallStr::EMPTY, chunk.data.height())) } fn field(&self, _input_schema: &Schema) -> PolarsResult { diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/eval.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/eval.rs index 84e504816daa..f81366a34641 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/eval.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/eval.rs @@ -1,6 +1,5 @@ use std::cell::UnsafeCell; -use polars_core::export::ahash::RandomState; use polars_row::{EncodingField, RowsEncoded}; use super::*; @@ -13,7 +12,7 @@ pub(super) struct Eval { key_columns_expr: Arc>>, // the columns that will be aggregated aggregation_columns_expr: Arc>>, - hb: RandomState, + hb: PlRandomState, // amortize allocations aggregation_series: UnsafeCell>, keys_columns: UnsafeCell>, @@ -28,7 +27,7 @@ impl Eval { key_columns: Arc>>, aggregation_columns: Arc>>, ) -> Self { - let hb = RandomState::default(); + let hb = PlRandomState::default(); Self { key_columns_expr: key_columns, aggregation_columns_expr: aggregation_columns, diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs index 2bb4f57b46a1..3e57db331b3e 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs @@ -259,7 +259,7 @@ impl AggHashTable { let key_dtypes = self .output_schema - .iter_dtypes() + .iter_values() .take(self.num_keys) .map(|dtype| dtype.to_physical().to_arrow(CompatLevel::newest())) .collect::>(); @@ -271,7 +271,7 @@ impl AggHashTable { cols.extend( key_columns .into_iter() - .map(|arr| Series::try_from(("", arr)).unwrap()), + .map(|arr| Series::try_from((PlSmallStr::EMPTY, arr)).unwrap()), ); cols.extend(agg_builders.into_iter().map(|buf| buf.into_series())); physical_agg_to_logical(&mut cols, &self.output_schema); diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/mod.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/mod.rs index 41967ee85426..55244679e204 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/mod.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/mod.rs @@ -65,7 +65,7 @@ impl SpillPayload { schema.with_column(INDEX_COL.into(), IDX_DTYPE); schema.with_column(KEYS_COL.into(), DataType::BinaryOffset); for s in &self.aggs { - schema.with_column(s.name().into(), s.dtype().clone()); + schema.with_column(s.name().clone(), s.dtype().clone()); } schema } @@ -74,9 +74,12 @@ impl SpillPayload { debug_assert_eq!(self.hashes.len(), self.chunk_idx.len()); debug_assert_eq!(self.hashes.len(), self.keys.len()); - let hashes = UInt64Chunked::from_vec(HASH_COL, self.hashes).into_series(); - let chunk_idx = IdxCa::from_vec(INDEX_COL, self.chunk_idx).into_series(); - let keys = BinaryOffsetChunked::with_chunk(KEYS_COL, self.keys).into_series(); + let hashes = + UInt64Chunked::from_vec(PlSmallStr::from_static(HASH_COL), self.hashes).into_series(); + let chunk_idx = + IdxCa::from_vec(PlSmallStr::from_static(INDEX_COL), self.chunk_idx).into_series(); + let keys = BinaryOffsetChunked::with_chunk(PlSmallStr::from_static(KEYS_COL), self.keys) + .into_series(); let mut cols = Vec::with_capacity(self.aggs.len() + 3); cols.push(hashes); diff --git 
a/crates/polars-pipe/src/executors/sinks/group_by/generic/sink.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/sink.rs index dd3231d5af7d..50a68cd34d27 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/sink.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/sink.rs @@ -28,7 +28,7 @@ impl GenericGroupby2 { ) -> Self { let key_dtypes: Arc<[DataType]> = Arc::from( output_schema - .iter_dtypes() + .iter_values() .take(key_columns.len()) .cloned() .collect::>(), diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/thread_local.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/thread_local.rs index 3554c24c7e65..e9edd3b22f25 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/thread_local.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/thread_local.rs @@ -139,7 +139,7 @@ impl SpillPartitions { .zip(self.output_schema.iter_names()) .map(|(b, name)| { let mut s = b.reset(OB_SIZE); - s.rename(name); + s.rename(name.clone()); s }) .collect(), diff --git a/crates/polars-pipe/src/executors/sinks/group_by/mod.rs b/crates/polars-pipe/src/executors/sinks/group_by/mod.rs index c2eaafe39d76..7a999e7e7cc7 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/mod.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/mod.rs @@ -16,7 +16,7 @@ pub(crate) use string::*; pub(super) fn physical_agg_to_logical(cols: &mut [Series], output_schema: &Schema) { for (s, (name, dtype)) in cols.iter_mut().zip(output_schema.iter()) { if s.name() != name { - s.rename(name); + s.rename(name.clone()); } match dtype { #[cfg(feature = "dtype-categorical")] diff --git a/crates/polars-pipe/src/executors/sinks/group_by/primitive/mod.rs b/crates/polars-pipe/src/executors/sinks/group_by/primitive/mod.rs index ecc0c9f09c68..d20ab9bf2b0d 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/primitive/mod.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/primitive/mod.rs @@ -7,7 +7,6 @@ use arrow::legacy::is_valid::IsValid; use arrow::legacy::kernels::sort_partition::partition_to_groups_amortized; use hashbrown::hash_map::RawEntryMut; use num_traits::NumCast; -use polars_core::export::ahash::RandomState; use polars_core::frame::row::AnyValueBuffer; use polars_core::prelude::*; use polars_core::series::IsSorted; @@ -62,7 +61,7 @@ pub struct PrimitiveGroupbySink { key: Arc, // the columns that will be aggregated aggregation_columns: Arc>>, - hb: RandomState, + hb: PlRandomState, // Initializing Aggregation functions. If we aggregate by 2 columns // this vec will have two functions. 
We will use these functions // to populate the buffer where the hashmap points to @@ -116,7 +115,7 @@ where io_thread: Option>>>, ooc: bool, ) -> Self { - let hb = RandomState::default(); + let hb = PlRandomState::default(); let partitions = _set_partition_size(); let pre_agg = load_vec(partitions, || PlIdHashMap::with_capacity(HASHMAP_INIT_SIZE)); @@ -174,7 +173,7 @@ where let agg_fns = unsafe { std::slice::from_raw_parts_mut(ptr, aggregators_len) }; let mut key_builder = PrimitiveChunkedBuilder::::new( - self.output_schema.get_at_index(0).unwrap().0, + self.output_schema.get_at_index(0).unwrap().0.clone(), agg_map.len(), ); let dtypes = agg_fns diff --git a/crates/polars-pipe/src/executors/sinks/group_by/string.rs b/crates/polars-pipe/src/executors/sinks/group_by/string.rs index c7369b5dd110..0a66255f6e7c 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/string.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/string.rs @@ -4,7 +4,6 @@ use std::sync::Mutex; use hashbrown::hash_map::RawEntryMut; use num_traits::NumCast; -use polars_core::export::ahash::RandomState; use polars_core::frame::row::AnyValueBuffer; use polars_core::prelude::*; use polars_core::utils::_set_partition_size; @@ -62,13 +61,13 @@ pub struct StringGroupbySink { // by: // * offset = (idx) // * end = (offset + 1) - keys: Vec>, + keys: Vec>, aggregators: Vec, // the key that will be aggregated on key_column: Arc, // the columns that will be aggregated aggregation_columns: Arc>>, - hb: RandomState, + hb: PlRandomState, // Initializing Aggregation functions. If we aggregate by 2 columns // this vec will have two functions. We will use these functions // to populate the buffer where the hashmap points to @@ -187,7 +186,7 @@ impl StringGroupbySink { .collect::>(); let cap = std::cmp::min(slice_len, agg_map.len()); - let mut key_builder = StringChunkedBuilder::new("", cap); + let mut key_builder = StringChunkedBuilder::new(PlSmallStr::EMPTY, cap); agg_map.into_iter().skip(offset).take(slice_len).for_each( |(k, &offset)| { let key_offset = k.idx as usize; @@ -583,7 +582,7 @@ fn get_entry<'a>( key_val: Option<&str>, h: u64, current_partition: &'a mut PlIdHashMap, - keys: &[Option], + keys: &[Option], ) -> RawEntryMut<'a, Key, IdxSize, IdBuildHasher> { current_partition.raw_entry_mut().from_hash(h, |key| { // first compare the hash before we incur the cache miss diff --git a/crates/polars-pipe/src/executors/sinks/io.rs b/crates/polars-pipe/src/executors/sinks/io.rs index 15cc6f8f2537..34d357ee18a4 100644 --- a/crates/polars-pipe/src/executors/sinks/io.rs +++ b/crates/polars-pipe/src/executors/sinks/io.rs @@ -170,6 +170,7 @@ impl IOThread { if let Some(partitions) = partitions { for (part, mut df) in partitions.into_no_null_iter().zip(iter) { df.shrink_to_fit(); + df.align_chunks(); let mut path = dir2.clone(); path.push(format!("{part}")); @@ -193,6 +194,7 @@ impl IOThread { for mut df in iter { df.shrink_to_fit(); + df.align_chunks(); writer.write_batch(&df).unwrap(); } writer.finish().unwrap(); @@ -240,7 +242,7 @@ impl IOThread { } pub(in crate::executors::sinks) fn dump_partition(&self, partition_no: IdxSize, df: DataFrame) { - let partition = Some(IdxCa::from_vec("", vec![partition_no])); + let partition = Some(IdxCa::from_vec(PlSmallStr::EMPTY, vec![partition_no])); let iter = Box::new(std::iter::once(df)); self.dump_iter(partition, iter) } diff --git a/crates/polars-pipe/src/executors/sinks/joins/cross.rs b/crates/polars-pipe/src/executors/sinks/joins/cross.rs index bac7c8243139..d6014c344978 100644 --- 
a/crates/polars-pipe/src/executors/sinks/joins/cross.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/cross.rs @@ -8,7 +8,7 @@ use polars_core::error::PolarsResult; use polars_core::frame::DataFrame; use polars_ops::prelude::CrossJoin as CrossJoinTrait; use polars_utils::arena::Node; -use smartstring::alias::String as SmartString; +use polars_utils::pl_str::PlSmallStr; use crate::executors::operators::PlaceHolder; use crate::operators::{ @@ -19,7 +19,7 @@ use crate::operators::{ #[derive(Default)] pub struct CrossJoin { chunks: Vec, - suffix: SmartString, + suffix: PlSmallStr, swapped: bool, node: Node, placeholder: PlaceHolder, @@ -27,7 +27,7 @@ pub struct CrossJoin { impl CrossJoin { pub(crate) fn new( - suffix: SmartString, + suffix: PlSmallStr, swapped: bool, node: Node, placeholder: PlaceHolder, @@ -73,7 +73,7 @@ impl Sink for CrossJoin { fn finalize(&mut self, _context: &PExecutionContext) -> PolarsResult { let op = Box::new(CrossJoinProbe { df: Arc::new(chunks_to_df_unchecked(std::mem::take(&mut self.chunks))), - suffix: Arc::from(self.suffix.as_ref()), + suffix: self.suffix.clone(), in_process_left: None, in_process_right: None, in_process_left_df: Default::default(), @@ -97,11 +97,11 @@ impl Sink for CrossJoin { #[derive(Clone)] pub struct CrossJoinProbe { df: Arc, - suffix: Arc, + suffix: PlSmallStr, in_process_left: Option>>, in_process_right: Option>>, in_process_left_df: DataFrame, - output_names: Option>, + output_names: Option>, swapped: bool, } @@ -159,7 +159,7 @@ impl Operator for CrossJoinProbe { (&self.in_process_left_df, &right_df) }; - let mut df = a.cross_join(b, Some(self.suffix.as_ref()), None)?; + let mut df = a.cross_join(b, Some(self.suffix.clone()), None)?; // Cross joins can produce multiple chunks. // No parallelize in operators df.as_single_chunk(); @@ -183,7 +183,7 @@ impl Operator for CrossJoinProbe { // this we can amortize the name allocations. 
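The comment above refers to a small caching trick: the first probe chunk pays for suffix resolution and name allocation, and later chunks reuse the cached names. A trimmed-down sketch of the first-chunk path (assumed local context; only APIs visible in this diff are used):

    // `suffix: PlSmallStr` and `output_names: Option<Vec<PlSmallStr>>` live on the operator.
    let df = a.cross_join(b, Some(suffix.clone()), None)?;
    if output_names.is_none() {
        // Cache the suffixed column names; subsequent chunks skip name
        // resolution and simply rename their columns to these PlSmallStr values.
        output_names = Some(df.get_column_names_owned());
    }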
let mut df = match &self.output_names { None => { - let df = a.cross_join(b, Some(self.suffix.as_ref()), None)?; + let df = a.cross_join(b, Some(self.suffix.clone()), None)?; self.output_names = Some(df.get_column_names_owned()); df }, diff --git a/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs b/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs index 8d1145bc6add..9703988e1eb3 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs @@ -2,14 +2,13 @@ use std::any::Any; use arrow::array::BinaryArray; use hashbrown::hash_map::RawEntryMut; -use polars_core::export::ahash::RandomState; use polars_core::prelude::*; use polars_core::utils::{_set_partition_size, accumulate_dataframes_vertical_unchecked}; use polars_ops::prelude::JoinArgs; use polars_utils::arena::Node; +use polars_utils::pl_str::PlSmallStr; use polars_utils::slice::GetSaferUnchecked; use polars_utils::unitvec; -use smartstring::alias::String as SmartString; use super::*; use crate::executors::operators::PlaceHolder; @@ -33,8 +32,8 @@ pub struct GenericBuild { // * chunk_offset = (idx * n_join_keys) // * end = (offset + n_join_keys) materialized_join_cols: Vec>, - suffix: Arc, - hb: RandomState, + suffix: PlSmallStr, + hb: PlRandomState, join_args: JoinArgs, // partitioned tables that will be used for probing // stores the key and the chunk_idx, df_idx of the left table @@ -51,26 +50,26 @@ pub struct GenericBuild { swapped: bool, join_nulls: bool, node: Node, - key_names_left: Arc<[SmartString]>, - key_names_right: Arc<[SmartString]>, + key_names_left: Arc<[PlSmallStr]>, + key_names_right: Arc<[PlSmallStr]>, placeholder: PlaceHolder, } impl GenericBuild { #[allow(clippy::too_many_arguments)] pub(crate) fn new( - suffix: Arc, + suffix: PlSmallStr, join_args: JoinArgs, swapped: bool, join_columns_left: Arc>>, join_columns_right: Arc>>, join_nulls: bool, node: Node, - key_names_left: Arc<[SmartString]>, - key_names_right: Arc<[SmartString]>, + key_names_left: Arc<[PlSmallStr]>, + key_names_right: Arc<[PlSmallStr]>, placeholder: PlaceHolder, ) -> Self { - let hb: RandomState = Default::default(); + let hb: PlRandomState = Default::default(); let partitions = _set_partition_size(); let hash_tables = PartitionedHashMap::new(load_vec(partitions, || { PlIdHashMap::with_capacity(HASHMAP_INIT_SIZE) diff --git a/crates/polars-pipe/src/executors/sinks/joins/generic_probe_inner_left.rs b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_inner_left.rs index 4201810550ee..5337d517cb79 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/generic_probe_inner_left.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_inner_left.rs @@ -1,14 +1,13 @@ use std::borrow::Cow; use arrow::array::{Array, BinaryArray}; -use polars_core::export::ahash::RandomState; use polars_core::prelude::*; use polars_core::series::IsSorted; use polars_ops::chunked_array::DfTake; use polars_ops::frame::join::_finish_join; use polars_ops::prelude::{JoinArgs, JoinType}; use polars_utils::nulls::IsNull; -use smartstring::alias::String as SmartString; +use polars_utils::pl_str::PlSmallStr; use crate::executors::sinks::joins::generic_build::*; use crate::executors::sinks::joins::row_values::RowValues; @@ -30,8 +29,8 @@ pub struct GenericJoinProbe { /// * chunk_offset = (idx * n_join_keys) /// * end = (offset + n_join_keys) materialized_join_cols: Arc<[BinaryArray]>, - suffix: Arc, - hb: RandomState, + suffix: PlSmallStr, + hb: 
PlRandomState, /// partitioned tables that will be used for probing /// stores the key and the chunk_idx, df_idx of the left table hash_tables: Arc>, @@ -47,7 +46,7 @@ pub struct GenericJoinProbe { /// the join order is swapped to ensure we hash the smaller table swapped_or_left: bool, /// cached output names - output_names: Option>, + output_names: Option>, args: JoinArgs, join_nulls: bool, row_values: RowValues, @@ -58,8 +57,8 @@ impl GenericJoinProbe { pub(super) fn new( mut df_a: DataFrame, materialized_join_cols: Arc<[BinaryArray]>, - suffix: Arc, - hb: RandomState, + suffix: PlSmallStr, + hb: PlRandomState, hash_tables: Arc>, join_columns_left: Arc>>, join_columns_right: Arc>>, @@ -84,10 +83,10 @@ impl GenericJoinProbe { phys_e .evaluate(&tmp, &context.execution_state) .ok() - .map(|s| s.name().to_string()) + .map(|s| s.name().clone()) }) - .collect::>(); - df_a = df_a.drop_many(&names) + .collect::>(); + df_a = df_a.drop_many_amortized(&names) } GenericJoinProbe { @@ -114,7 +113,7 @@ impl GenericJoinProbe { ) -> PolarsResult { Ok(match &self.output_names { None => { - let out = _finish_join(left_df, right_df, Some(self.suffix.as_ref()))?; + let out = _finish_join(left_df, right_df, Some(self.suffix.clone()))?; self.output_names = Some(out.get_column_names_owned()); out }, @@ -130,7 +129,7 @@ impl GenericJoinProbe { .iter_mut() .zip(names) .for_each(|(s, name)| { - s.rename(name); + s.rename(name.clone()); }); left_df }, diff --git a/crates/polars-pipe/src/executors/sinks/joins/generic_probe_outer.rs b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_outer.rs index 3ae57fa929a0..0157fe660de5 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/generic_probe_outer.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_outer.rs @@ -1,13 +1,12 @@ use std::sync::atomic::Ordering; use arrow::array::{Array, BinaryArray, MutablePrimitiveArray}; -use polars_core::export::ahash::RandomState; use polars_core::prelude::*; use polars_core::series::IsSorted; use polars_ops::chunked_array::DfTake; use polars_ops::frame::join::_finish_join; use polars_ops::prelude::_coalesce_full_join; -use smartstring::alias::String as SmartString; +use polars_utils::pl_str::PlSmallStr; use crate::executors::sinks::joins::generic_build::*; use crate::executors::sinks::joins::row_values::RowValues; @@ -32,8 +31,8 @@ pub struct GenericFullOuterJoinProbe { /// * chunk_offset = (idx * n_join_keys) /// * end = (offset + n_join_keys) materialized_join_cols: Arc<[BinaryArray]>, - suffix: Arc, - hb: RandomState, + suffix: PlSmallStr, + hb: PlRandomState, /// partitioned tables that will be used for probing. /// stores the key and the chunk_idx, df_idx of the left table. 
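The `ahash::RandomState` imports are replaced by the crate-local `PlRandomState` alias so the build and probe sides hash with one shared builder type. A minimal, hedged sketch of using such a builder (assumes `PlRandomState` implements `std::hash::BuildHasher`, which its use for row hashing here implies):

    use std::hash::BuildHasher;
    use polars_utils::aliases::PlRandomState;

    let hb = PlRandomState::default();
    let h1 = hb.hash_one(b"join key".as_slice());
    let h2 = hb.hash_one(b"join key".as_slice());
    assert_eq!(h1, h2); // one builder => identical hashes on the build and probe sides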
hash_tables: Arc>, @@ -49,13 +48,13 @@ pub struct GenericFullOuterJoinProbe { // the join order is swapped to ensure we hash the smaller table swapped: bool, // cached output names - output_names: Option>, + output_names: Option>, join_nulls: bool, coalesce: bool, thread_no: usize, row_values: RowValues, - key_names_left: Arc<[SmartString]>, - key_names_right: Arc<[SmartString]>, + key_names_left: Arc<[PlSmallStr]>, + key_names_right: Arc<[PlSmallStr]>, } impl GenericFullOuterJoinProbe { @@ -63,8 +62,8 @@ impl GenericFullOuterJoinProbe { pub(super) fn new( df_a: DataFrame, materialized_join_cols: Arc<[BinaryArray]>, - suffix: Arc, - hb: RandomState, + suffix: PlSmallStr, + hb: PlRandomState, hash_tables: Arc>, join_columns_right: Arc>>, swapped: bool, @@ -72,8 +71,8 @@ impl GenericFullOuterJoinProbe { amortized_hashes: Vec, join_nulls: bool, coalesce: bool, - key_names_left: Arc<[SmartString]>, - key_names_right: Arc<[SmartString]>, + key_names_left: Arc<[PlSmallStr]>, + key_names_right: Arc<[PlSmallStr]>, ) -> Self { GenericFullOuterJoinProbe { df_a: Arc::new(df_a), @@ -100,9 +99,9 @@ impl GenericFullOuterJoinProbe { fn inner( left_df: DataFrame, right_df: DataFrame, - suffix: &str, + suffix: PlSmallStr, swapped: bool, - output_names: &mut Option>, + output_names: &mut Option>, ) -> PolarsResult { let (mut left_df, right_df) = if swapped { (right_df, left_df) @@ -127,7 +126,7 @@ impl GenericFullOuterJoinProbe { .iter_mut() .zip(names) .for_each(|(s, name)| { - s.rename(name); + s.rename(name.clone()); }); left_df }, @@ -138,32 +137,24 @@ impl GenericFullOuterJoinProbe { let out = inner( left_df.clone(), right_df, - self.suffix.as_ref(), + self.suffix.clone(), self.swapped, &mut self.output_names, )?; - let l = self - .key_names_left - .iter() - .map(|s| s.as_str()) - .collect::>(); - let r = self - .key_names_right - .iter() - .map(|s| s.as_str()) - .collect::>(); + let l = self.key_names_left.iter().cloned().collect::>(); + let r = self.key_names_right.iter().cloned().collect::>(); Ok(_coalesce_full_join( out, - &l, - &r, - Some(self.suffix.as_ref()), + l.as_slice(), + r.as_slice(), + Some(self.suffix.clone()), &left_df, )) } else { inner( left_df.clone(), right_df, - self.suffix.as_ref(), + self.suffix.clone(), self.swapped, &mut self.output_names, ) @@ -277,7 +268,7 @@ impl GenericFullOuterJoinProbe { right_df .get_columns() .iter() - .map(|s| Series::full_null(s.name(), size, s.dtype())) + .map(|s| Series::full_null(s.name().clone(), size, s.dtype())) .collect(), ) }; diff --git a/crates/polars-pipe/src/executors/sinks/reproject.rs b/crates/polars-pipe/src/executors/sinks/reproject.rs index 8d66e102fd92..bd9553b75f97 100644 --- a/crates/polars-pipe/src/executors/sinks/reproject.rs +++ b/crates/polars-pipe/src/executors/sinks/reproject.rs @@ -40,7 +40,7 @@ impl Sink for ReProjectSink { fn finalize(&mut self, context: &PExecutionContext) -> PolarsResult { Ok(match self.sink.finalize(context)? { FinalizedSink::Finished(df) => { - FinalizedSink::Finished(df.select(self.schema.iter_names())?) + FinalizedSink::Finished(df.select(self.schema.iter_names_cloned())?) 
}, FinalizedSink::Source(source) => { FinalizedSink::Source(Box::new(ReProjectSource::new(self.schema.clone(), source))) diff --git a/crates/polars-pipe/src/executors/sinks/sort/sink.rs b/crates/polars-pipe/src/executors/sinks/sort/sink.rs index 5bd51deba54a..b6c5316485b7 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/sink.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/sink.rs @@ -8,6 +8,7 @@ use polars_core::error::PolarsResult; use polars_core::frame::DataFrame; use polars_core::prelude::{AnyValue, SchemaRef, Series, SortOptions}; use polars_core::utils::accumulate_dataframes_vertical_unchecked; +use polars_utils::pl_str::PlSmallStr; use crate::executors::sinks::io::{block_thread_until_io_thread_done, IOThread}; use crate::executors::sinks::memory::MemTracker; @@ -190,7 +191,7 @@ impl Sink for SortSink { let mut lock = self.io_thread.write().unwrap(); let io_thread = lock.take().unwrap(); - let dist = Series::from_any_values("", &self.dist_sample, true).unwrap(); + let dist = Series::from_any_values(PlSmallStr::EMPTY, &self.dist_sample, true).unwrap(); let dist = dist.sort_with(SortOptions::from(&self.sort_options))?; let instant = self.ooc_start.unwrap(); diff --git a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs index c7256f084aeb..053ccb1f1999 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs @@ -98,8 +98,9 @@ fn finalize_dataframe( for (sort_idx, arr) in sort_idx.into_iter().zip(arrays) { let (name, logical_dtype) = schema.get_at_index(sort_idx).unwrap(); - assert_eq!(logical_dtype.to_physical(), DataType::from(arr.data_type())); - let col = Series::from_chunks_and_dtype_unchecked(name, vec![arr], logical_dtype); + assert_eq!(logical_dtype.to_physical(), DataType::from(arr.dtype())); + let col = + Series::from_chunks_and_dtype_unchecked(name.clone(), vec![arr], logical_dtype); cols.insert(sort_idx, col); } } @@ -227,7 +228,7 @@ impl SortSinkMultiple { let rows_encoded = polars_row::convert_columns(&self.sort_column, &self.sort_fields); let column = unsafe { Series::from_chunks_and_dtype_unchecked( - POLARS_SORT_COLUMN, + PlSmallStr::from_static(POLARS_SORT_COLUMN), vec![Box::new(rows_encoded.into_array())], &DataType::BinaryOffset, ) diff --git a/crates/polars-pipe/src/executors/sinks/utils.rs b/crates/polars-pipe/src/executors/sinks/utils.rs index 9f868a52228b..bd8b8bd9e4a6 100644 --- a/crates/polars-pipe/src/executors/sinks/utils.rs +++ b/crates/polars-pipe/src/executors/sinks/utils.rs @@ -1,8 +1,8 @@ use arrow::array::BinaryArray; -use polars_core::export::ahash::RandomState; use polars_core::hashing::_hash_binary_array; +use polars_utils::aliases::PlRandomState; -pub(super) fn hash_rows(columns: &BinaryArray, buf: &mut Vec, hb: &RandomState) { +pub(super) fn hash_rows(columns: &BinaryArray, buf: &mut Vec, hb: &PlRandomState) { debug_assert!(buf.is_empty()); _hash_binary_array(columns, hb.clone(), buf); } diff --git a/crates/polars-pipe/src/executors/sources/csv.rs b/crates/polars-pipe/src/executors/sources/csv.rs index 548bd2496334..f3267ac1e90a 100644 --- a/crates/polars-pipe/src/executors/sources/csv.rs +++ b/crates/polars-pipe/src/executors/sources/csv.rs @@ -1,12 +1,13 @@ use std::fs::File; -use std::path::PathBuf; +use polars_core::error::feature_gated; use polars_core::{config, POOL}; use polars_io::csv::read::{BatchedCsvReader, CsvReadOptions, CsvReader}; use 
polars_io::path_utils::is_cloud_url; use polars_plan::global::_set_n_rows_for_scan; +use polars_plan::plans::ScanSources; use polars_plan::prelude::FileScanOptions; -use polars_utils::iter::EnumerateIdxTrait; +use polars_utils::itertools::Itertools; use super::*; use crate::pipeline::determine_chunk_size; @@ -20,7 +21,7 @@ pub(crate) struct CsvSource { batched_reader: Option>, reader: Option>, n_threads: usize, - paths: Arc>, + sources: ScanSources, options: Option, file_options: FileScanOptions, verbose: bool, @@ -36,6 +37,10 @@ impl CsvSource { // otherwise all files would be opened during construction of the pipeline // leading to Too many Open files error fn init_next_reader(&mut self) -> PolarsResult<()> { + let paths = self + .sources + .as_paths() + .ok_or_else(|| polars_err!(nyi = "Streaming scanning of in-memory buffers"))?; let file_options = self.file_options.clone(); let n_rows = file_options.slice.map(|x| { @@ -43,12 +48,12 @@ impl CsvSource { x.1 }); - if self.current_path_idx == self.paths.len() + if self.current_path_idx == paths.len() || (n_rows.is_some() && n_rows.unwrap() <= self.n_rows_read) { return Ok(()); } - let path = &self.paths[self.current_path_idx]; + let path = &paths[self.current_path_idx]; let force_async = config::force_async(); let run_async = force_async || is_cloud_url(path); @@ -104,8 +109,7 @@ impl CsvSource { .with_row_index(row_index); let reader: CsvReader = if run_async { - #[cfg(feature = "cloud")] - { + feature_gated!("cloud", { options.into_reader_with_file_handle( polars_io::file_cache::FILE_CACHE .get_entry(path.to_str().unwrap()) @@ -113,11 +117,7 @@ impl CsvSource { .unwrap() .try_open_assume_latest()?, ) - } - #[cfg(not(feature = "cloud"))] - { - panic!("required feature `cloud` is not enabled") - } + }) } else { options .with_path(Some(path)) @@ -125,7 +125,8 @@ impl CsvSource { }; if let Some(col) = &file_options.include_file_paths { - self.include_file_path = Some(StringChunked::full(col, path.to_str().unwrap(), 1)); + self.include_file_path = + Some(StringChunked::full(col.clone(), path.to_str().unwrap(), 1)); }; self.reader = Some(reader); @@ -139,7 +140,7 @@ impl CsvSource { } pub(crate) fn new( - paths: Arc>, + sources: ScanSources, schema: SchemaRef, options: CsvReadOptions, file_options: FileScanOptions, @@ -150,7 +151,7 @@ impl CsvSource { reader: None, batched_reader: None, n_threads: POOL.current_num_threads(), - paths, + sources, options: Some(options), file_options, verbose, @@ -211,7 +212,7 @@ impl Source for CsvSource { if let Some(ca) = &mut self.include_file_path { if ca.len() < max_height { - *ca = ca.new_from_index(max_height, 0); + *ca = ca.new_from_index(0, max_height); }; for data_chunk in &mut out { diff --git a/crates/polars-pipe/src/executors/sources/parquet.rs b/crates/polars-pipe/src/executors/sources/parquet.rs index c77910d24adc..8592021b2ff3 100644 --- a/crates/polars-pipe/src/executors/sources/parquet.rs +++ b/crates/polars-pipe/src/executors/sources/parquet.rs @@ -1,8 +1,10 @@ use std::collections::VecDeque; use std::ops::Range; use std::path::PathBuf; +use std::sync::atomic::AtomicUsize; use std::sync::Arc; +use futures::{StreamExt, TryStreamExt}; use polars_core::config::{self, get_file_prefetch_size}; use polars_core::error::*; use polars_core::prelude::Series; @@ -18,10 +20,10 @@ use polars_io::prelude::materialize_projection; use polars_io::prelude::ParquetAsyncReader; use polars_io::utils::slice::split_slice_at_file; use polars_io::SerReader; -use polars_plan::plans::FileInfo; +use 
polars_plan::plans::{FileInfo, ScanSources}; use polars_plan::prelude::hive::HivePartitions; use polars_plan::prelude::FileScanOptions; -use polars_utils::iter::EnumerateIdxTrait; +use polars_utils::itertools::Itertools; use polars_utils::IdxSize; use crate::executors::sources::get_source_index; @@ -32,9 +34,9 @@ pub struct ParquetSource { batched_readers: VecDeque, n_threads: usize, processed_paths: usize, - processed_rows: usize, + processed_rows: AtomicUsize, iter: Range, - paths: Arc>, + sources: ScanSources, options: ParquetOptions, file_options: FileScanOptions, #[allow(dead_code)] @@ -75,7 +77,11 @@ impl ParquetSource { usize, Option>, )> { - let path = &self.paths[index]; + let paths = self + .sources + .as_paths() + .ok_or_else(|| polars_err!(nyi = "Streaming scanning of in-memory buffers"))?; + let path = &paths[index]; let options = self.options; let file_options = self.file_options.clone(); let schema = self.file_info.schema.clone(); @@ -110,11 +116,13 @@ impl ParquetSource { } fn init_reader_sync(&mut self) -> PolarsResult<()> { + use std::sync::atomic::Ordering; + let Some(index) = self.iter.next() else { return Ok(()); }; if let Some(slice) = self.file_options.slice { - if self.processed_rows >= slice.0 as usize + slice.1 { + if self.processed_rows.load(Ordering::Relaxed) >= slice.0 as usize + slice.1 { return Ok(()); } } @@ -147,20 +155,22 @@ impl ParquetSource { ); let n_rows_this_file = reader.num_rows().unwrap(); + let current_row_offset = self + .processed_rows + .fetch_add(n_rows_this_file, Ordering::Relaxed); let slice = file_options.slice.map(|slice| { assert!(slice.0 >= 0); let slice_start = slice.0 as usize; let slice_end = slice_start + slice.1; split_slice_at_file( - &mut self.processed_rows.clone(), + &mut current_row_offset.clone(), n_rows_this_file, slice_start, slice_end, ) }); - self.processed_rows += n_rows_this_file; reader = reader.with_slice(slice); reader.batched(chunk_size)? }; @@ -174,42 +184,64 @@ impl ParquetSource { Ok(()) } + /// This function must NOT be run concurrently if there is a slice (or any operation that + /// requires `self.processed_rows` to be incremented in the correct order), as it does not + /// coordinate to increment the row offset in a properly ordered manner. #[cfg(feature = "async")] async fn init_reader_async(&self, index: usize) -> PolarsResult { + use std::sync::atomic::Ordering; + let metadata = self.metadata.clone(); let predicate = self.predicate.clone(); let cloud_options = self.cloud_options.clone(); let (path, options, file_options, projection, chunk_size, hive_partitions) = self.prepare_init_reader(index)?; - assert_eq!(file_options.slice, None); - let batched_reader = { let uri = path.to_string_lossy(); - ParquetAsyncReader::from_uri(&uri, cloud_options.as_ref(), metadata) - .await? - .with_row_index(file_options.row_index) - .with_projection(projection) - .check_schema( - self.file_info - .reader_schema - .as_ref() - .unwrap() - .as_ref() - .unwrap_left(), - ) - .await? - .with_predicate(predicate.clone()) - .use_statistics(options.use_statistics) - .with_hive_partition_columns(hive_partitions) - .with_include_file_path( - self.file_options - .include_file_paths - .as_ref() - .map(|x| (x.clone(), Arc::from(path.to_str().unwrap()))), + + let mut async_reader = + ParquetAsyncReader::from_uri(&uri, cloud_options.as_ref(), metadata) + .await? 
+ .with_row_index(file_options.row_index) + .with_projection(projection) + .check_schema( + self.file_info + .reader_schema + .as_ref() + .unwrap() + .as_ref() + .unwrap_left(), + ) + .await? + .with_predicate(predicate.clone()) + .use_statistics(options.use_statistics) + .with_hive_partition_columns(hive_partitions) + .with_include_file_path( + self.file_options + .include_file_paths + .as_ref() + .map(|x| (x.clone(), Arc::from(path.to_str().unwrap()))), + ); + + let n_rows_this_file = async_reader.num_rows().await?; + let current_row_offset = self + .processed_rows + .fetch_add(n_rows_this_file, Ordering::Relaxed); + + let slice = file_options.slice.map(|slice| { + assert!(slice.0 >= 0); + let slice_start = slice.0 as usize; + let slice_end = slice_start + slice.1; + split_slice_at_file( + &mut current_row_offset.clone(), + n_rows_this_file, + slice_start, + slice_end, ) - .batched(chunk_size) - .await? + }); + + async_reader.with_slice(slice).batched(chunk_size).await? }; Ok(batched_reader) } @@ -217,7 +249,7 @@ impl ParquetSource { #[allow(unused_variables)] #[allow(clippy::too_many_arguments)] pub(crate) fn new( - paths: Arc>, + sources: ScanSources, options: ParquetOptions, cloud_options: Option, metadata: Option, @@ -227,6 +259,9 @@ impl ParquetSource { verbose: bool, predicate: Option>, ) -> PolarsResult { + let paths = sources + .as_paths() + .ok_or_else(|| polars_err!(nyi = "Streaming scanning of in-memory buffers"))?; let n_threads = POOL.current_num_threads(); let iter = 0..paths.len(); @@ -241,11 +276,11 @@ impl ParquetSource { batched_readers: VecDeque::new(), n_threads, processed_paths: 0, - processed_rows: 0, + processed_rows: AtomicUsize::new(0), options, file_options, iter, - paths, + sources, cloud_options, metadata, file_info, @@ -269,29 +304,36 @@ impl ParquetSource { // // It is important we do this for a reasonable batch size, that's why we start this when we // have just 2 readers left. - if self.file_options.slice.is_none() - && self.run_async - && (self.batched_readers.len() <= 2 || self.batched_readers.is_empty()) - { + if self.run_async { #[cfg(not(feature = "async"))] panic!("activate 'async' feature"); #[cfg(feature = "async")] { - let range = 0..self.prefetch_size - self.batched_readers.len(); - let range = range - .zip(&mut self.iter) - .map(|(_, index)| index) - .collect::>(); - let init_iter = range.into_iter().map(|index| self.init_reader_async(index)); - - let batched_readers = - polars_io::pl_async::get_runtime().block_on_potential_spawn(async { - futures::future::try_join_all(init_iter).await - })?; - - for r in batched_readers { - self.finish_init_reader(r)?; + if self.batched_readers.len() <= 2 || self.batched_readers.is_empty() { + let range = 0..self.prefetch_size - self.batched_readers.len(); + let range = range + .zip(&mut self.iter) + .map(|(_, index)| index) + .collect::>(); + let init_iter = range.into_iter().map(|index| self.init_reader_async(index)); + + let batched_readers = if self.file_options.slice.is_some() { + polars_io::pl_async::get_runtime().block_on_potential_spawn(async { + futures::stream::iter(init_iter) + .then(|x| x) + .try_collect() + .await + })? + } else { + polars_io::pl_async::get_runtime().block_on_potential_spawn(async { + futures::future::try_join_all(init_iter).await + })? 
+ }; + + for r in batched_readers { + self.finish_init_reader(r)?; + } } } } else { diff --git a/crates/polars-pipe/src/operators/chunks.rs b/crates/polars-pipe/src/operators/chunks.rs index 10b89784eaa3..1c78a32dde80 100644 --- a/crates/polars-pipe/src/operators/chunks.rs +++ b/crates/polars-pipe/src/operators/chunks.rs @@ -138,7 +138,7 @@ mod test { .iter() .enumerate() .map(|(i, length)| { - let series = Series::new("val", vec![i as u64; *length]); + let series = Series::new("val".into(), vec![i as u64; *length]); DataFrame::new(vec![series]).unwrap() }) .collect(); diff --git a/crates/polars-pipe/src/pipeline/convert.rs b/crates/polars-pipe/src/pipeline/convert.rs index 799e956c1378..0a6a8946feba 100644 --- a/crates/polars-pipe/src/pipeline/convert.rs +++ b/crates/polars-pipe/src/pipeline/convert.rs @@ -67,14 +67,14 @@ where } // projection is free if let Some(schema) = output_schema { - let columns = schema.iter_names().cloned().collect::>(); + let columns = schema.iter_names_cloned().collect::>(); df = df._select_impl_unchecked(&columns)?; } } Ok(Box::new(sources::DataFrameSource::from_df(df)) as Box) }, Scan { - paths, + sources, file_info, hive_parts, file_options, @@ -82,6 +82,8 @@ where output_schema, scan_type, } => { + let paths = sources.into_paths(); + // Add predicate to operators. // Except for parquet, as that format can use statistics to prune file/row-groups. #[cfg(feature = "parquet")] @@ -102,7 +104,7 @@ where #[cfg(feature = "csv")] FileScan::Csv { options, .. } => { let src = sources::CsvSource::new( - paths, + sources, file_info.schema, options, file_options, @@ -131,7 +133,7 @@ where self.p.evaluate_io(df) } - fn live_variables(&self) -> Option>> { + fn live_variables(&self) -> Option> { None } @@ -144,7 +146,7 @@ where }) .transpose()?; let src = sources::ParquetSource::new( - paths, + sources, parquet_options, cloud_options, metadata, @@ -259,7 +261,7 @@ where match &options.args.how { #[cfg(feature = "cross_join")] JoinType::Cross => Box::new(CrossJoin::new( - options.args.suffix().into(), + options.args.suffix().clone(), swapped, node, placeholder, @@ -293,7 +295,7 @@ where let (join_columns_left, join_columns_right) = swap_eval(); Box::new(GenericBuild::<()>::new( - Arc::from(options.args.suffix()), + options.args.suffix().clone(), options.args.clone(), swapped, join_columns_left, @@ -320,7 +322,7 @@ where let (join_columns_left, join_columns_right) = swap_eval(); Box::new(GenericBuild::::new( - Arc::from(options.args.suffix()), + options.args.suffix().clone(), options.args.clone(), swapped, join_columns_left, @@ -390,7 +392,7 @@ where let keys = input_schema .iter_names() .map(|name| { - let name: Arc = Arc::from(name.as_str()); + let name: PlSmallStr = name.clone(); let node = expr_arena.add(AExpr::Column(name.clone())); ExprIR::new(node, OutputName::Alias(name)) }) @@ -404,11 +406,10 @@ where let keys = keys .iter() .map(|key| { - let (_, name, dtype) = input_schema.get_full(key.as_str()).unwrap(); + let (_, name, dtype) = input_schema.get_full(key.as_ref()).unwrap(); group_by_out_schema.with_column(name.clone(), dtype.clone()); - let name: Arc = Arc::from(key.as_str()); - let node = expr_arena.add(AExpr::Column(name.clone())); - ExprIR::new(node, OutputName::Alias(name)) + let node = expr_arena.add(AExpr::Column(key.clone())); + ExprIR::new(node, OutputName::Alias(key.clone())) }) .collect::>(); @@ -422,7 +423,7 @@ where input_schema.get_full(name.as_str()).unwrap(); group_by_out_schema.with_column(name.clone(), dtype.clone()); - let name: Arc = 
Arc::from(name.as_str()); + let name: PlSmallStr = name.clone(); let col = expr_arena.add(AExpr::Column(name.clone())); let node = match options.keep_strategy { UniqueKeepStrategy::First | UniqueKeepStrategy::Any => { @@ -589,7 +590,7 @@ where let op = match lp_arena.get(node) { SimpleProjection { input, columns, .. } => { let input_schema = lp_arena.get(*input).schema(lp_arena); - let columns = columns.iter_names().cloned().collect(); + let columns = columns.iter_names_cloned().collect(); let op = operators::SimpleProjectionOperator::new(columns, input_schema.into_owned()); Box::new(op) as Box }, diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index 3ad50ace7fd8..dd33428c8398 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -26,6 +26,7 @@ ahash = { workspace = true } arrow = { workspace = true } bitflags = { workspace = true } bytemuck = { workspace = true } +bytes = { workspace = true } chrono = { workspace = true, optional = true } chrono-tz = { workspace = true, optional = true } ciborium = { workspace = true, optional = true } @@ -41,7 +42,6 @@ recursive = { workspace = true } regex = { workspace = true, optional = true } serde = { workspace = true, features = ["rc"], optional = true } serde_json = { workspace = true, optional = true } -smartstring = { workspace = true } strum_macros = { workspace = true } [build-dependencies] @@ -57,6 +57,7 @@ serde = [ "polars-time/serde", "polars-io/serde", "polars-ops/serde", + "polars-utils/serde", "either/serde", ] streaming = [] @@ -67,6 +68,7 @@ ipc = ["polars-io/ipc"] json = ["polars-io/json", "polars-json"] csv = ["polars-io/csv"] temporal = [ + "chrono", "polars-core/temporal", "polars-core/dtype-date", "polars-core/dtype-datetime", @@ -85,7 +87,7 @@ dtype-i16 = ["polars-core/dtype-i16"] dtype-decimal = ["polars-core/dtype-decimal"] dtype-date = ["polars-time/dtype-date", "temporal"] dtype-datetime = ["polars-time/dtype-datetime", "temporal"] -dtype-duration = ["polars-core/dtype-duration", "polars-time/dtype-duration", "temporal"] +dtype-duration = ["polars-core/dtype-duration", "polars-time/dtype-duration", "temporal", "polars-ops/dtype-duration"] dtype-time = ["polars-time/dtype-time", "temporal"] dtype-array = ["polars-core/dtype-array", "polars-ops/dtype-array"] dtype-categorical = ["polars-core/dtype-categorical"] @@ -183,6 +185,7 @@ offset_by = ["polars-time/offset_by"] bigidx = ["polars-core/bigidx"] polars_cloud = ["serde", "ciborium"] +ir_serde = ["serde", "polars-utils/ir_serde"] panic_on_schema = [] diff --git a/crates/polars-plan/README.md b/crates/polars-plan/README.md index 23d78053d6da..59fce1861941 100644 --- a/crates/polars-plan/README.md +++ b/crates/polars-plan/README.md @@ -1,4 +1,4 @@ -# polars-plan- +# polars-plan `polars-plan` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library, that provides source code responsible for Polars logical planning. 
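The hunks in this patch repeatedly apply one mechanical migration: column names move from borrowed `&str` / `Arc<str>` / SmartString values to the owned `PlSmallStr` type, so constructors now take an owned name and accessors hand back a `&PlSmallStr` that callers clone where ownership is needed. The following standalone sketch illustrates that calling convention; it is not part of the patch, the function and variable names are made up for illustration, and it only uses calls that appear in the hunks above (`Series::new`, `Series::full_null`, `name()`, `with_name`, `PlSmallStr::from_static`, `PlSmallStr::EMPTY`).

    use polars_core::prelude::*;
    use polars_utils::pl_str::PlSmallStr;

    // Illustrative only: the owned-name convention used throughout this patch.
    fn pl_small_str_conventions() -> PolarsResult<()> {
        // Constructors take an owned name; a `&str` converts via `.into()`.
        let s = Series::new("val".into(), &[1i64, 2, 3]);

        // `name()` returns `&PlSmallStr`; clone it where an owned name is required.
        let nulls = Series::full_null(s.name().clone(), s.len(), s.dtype());

        // Static names use `PlSmallStr::from_static`, and `PlSmallStr::EMPTY`
        // replaces the old `""` literals.
        let renamed = nulls.with_name(PlSmallStr::from_static("null_copy"));
        let anon = Series::new(PlSmallStr::EMPTY, &[true, false, true]);

        assert_eq!(renamed.len(), s.len());
        assert!(anon.name().is_empty());
        Ok(())
    }
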
diff --git a/crates/polars-plan/src/client/check.rs b/crates/polars-plan/src/client/check.rs index 99035d72a77c..84189840a3dd 100644 --- a/crates/polars-plan/src/client/check.rs +++ b/crates/polars-plan/src/client/check.rs @@ -1,62 +1,43 @@ use polars_core::error::{polars_err, PolarsResult}; use polars_io::path_utils::is_cloud_url; -use crate::dsl::Expr; use crate::plans::options::SinkType; -use crate::plans::{DslFunction, DslPlan, FileScan, FunctionNode}; +use crate::plans::{DslPlan, FileScan, ScanSources}; /// Assert that the given [`DslPlan`] is eligible to be executed on Polars Cloud. pub(super) fn assert_cloud_eligible(dsl: &DslPlan) -> PolarsResult<()> { - let mut expr_stack = vec![]; for plan_node in dsl.into_iter() { match plan_node { - DslPlan::MapFunction { - function: DslFunction::FunctionNode(function), - .. - } => match function { - FunctionNode::Opaque { .. } => return ineligible_error("contains opaque function"), - #[cfg(feature = "python")] - FunctionNode::OpaquePython { .. } => { - return ineligible_error("contains Python function") - }, - _ => (), - }, #[cfg(feature = "python")] DslPlan::PythonScan { .. } => return ineligible_error("contains Python scan"), - DslPlan::GroupBy { apply: Some(_), .. } => { - return ineligible_error("contains Python function in group by operation") - }, - DslPlan::Scan { paths, .. } - if paths.lock().unwrap().0.iter().any(|p| !is_cloud_url(p)) => - { - return ineligible_error("contains scan of local file system") - }, DslPlan::Scan { - scan_type: FileScan::Anonymous { .. }, - .. - } => return ineligible_error("contains anonymous scan"), + sources, scan_type, .. + } => { + let sources_lock = sources.lock().unwrap(); + match &sources_lock.sources { + ScanSources::Paths(paths) => { + if paths.iter().any(|p| !is_cloud_url(p)) { + return ineligible_error("contains scan of local file system"); + } + }, + ScanSources::Files(_) => { + return ineligible_error("contains scan of opened files"); + }, + ScanSources::Buffers(_) => { + return ineligible_error("contains scan of in-memory buffer"); + }, + } + + if matches!(scan_type, FileScan::Anonymous { .. }) { + return ineligible_error("contains anonymous scan"); + } + }, DslPlan::Sink { payload, .. } => { if !matches!(payload, SinkType::Cloud { .. }) { return ineligible_error("contains sink to non-cloud location"); } }, - plan => { - plan.get_expr(&mut expr_stack); - - for expr in expr_stack.drain(..) { - for expr_node in expr.into_iter() { - match expr_node { - Expr::AnonymousFunction { .. } => { - return ineligible_error("contains anonymous function") - }, - Expr::RenameAlias { .. } => { - return ineligible_error("contains custom name remapping") - }, - _ => (), - } - } - } - }, + _ => (), } } Ok(()) @@ -102,47 +83,6 @@ impl DslPlan { PythonScan { .. } => (), } } - - fn get_expr<'a>(&'a self, scratch: &mut Vec<&'a Expr>) { - use DslPlan::*; - match self { - Filter { predicate, .. } => scratch.push(predicate), - Scan { predicate, .. } => { - if let Some(expr) = predicate { - scratch.push(expr) - } - }, - DataFrameScan { filter, .. } => { - if let Some(expr) = filter { - scratch.push(expr) - } - }, - Select { expr, .. } => scratch.extend(expr), - HStack { exprs, .. } => scratch.extend(exprs), - Sort { by_column, .. } => scratch.extend(by_column), - GroupBy { keys, aggs, .. } => { - scratch.extend(keys); - scratch.extend(aggs); - }, - Join { - left_on, right_on, .. - } => { - scratch.extend(left_on); - scratch.extend(right_on); - }, - Cache { .. } - | Distinct { .. } - | Slice { .. } - | MapFunction { .. 
} - | Union { .. } - | HConcat { .. } - | ExtContext { .. } - | Sink { .. } - | IR { .. } => (), - #[cfg(feature = "python")] - PythonScan { .. } => (), - } - } } pub struct DslPlanIter<'a> { diff --git a/crates/polars-plan/src/client/mod.rs b/crates/polars-plan/src/client/mod.rs index a815babcc6ad..f5a5cdb0f763 100644 --- a/crates/polars-plan/src/client/mod.rs +++ b/crates/polars-plan/src/client/mod.rs @@ -1,38 +1,18 @@ mod check; -use std::sync::Arc; +use arrow::legacy::error::to_compute_err; +use polars_core::error::PolarsResult; -use polars_core::error::{polars_ensure, polars_err, PolarsResult}; -use polars_io::parquet::write::ParquetWriteOptions; -use polars_io::path_utils::is_cloud_url; - -use crate::plans::options::{FileType, SinkType}; use crate::plans::DslPlan; /// Prepare the given [`DslPlan`] for execution on Polars Cloud. -pub fn prepare_cloud_plan(dsl: DslPlan, uri: String) -> PolarsResult> { +pub fn prepare_cloud_plan(dsl: DslPlan) -> PolarsResult> { // Check the plan for cloud eligibility. check::assert_cloud_eligible(&dsl)?; - // Add Sink node. - polars_ensure!( - is_cloud_url(&uri), - InvalidOperation: "non-cloud paths not supported: {uri}" - ); - let sink_type = SinkType::Cloud { - uri: Arc::new(uri), - file_type: FileType::Parquet(ParquetWriteOptions::default()), - cloud_options: None, - }; - let dsl = DslPlan::Sink { - input: Arc::new(dsl), - payload: sink_type, - }; - // Serialize the plan. let mut writer = Vec::new(); - ciborium::into_writer(&dsl, &mut writer) - .map_err(|err| polars_err!(ComputeError: err.to_string()))?; + ciborium::into_writer(&dsl, &mut writer).map_err(to_compute_err)?; Ok(writer) } diff --git a/crates/polars-plan/src/constants.rs b/crates/polars-plan/src/constants.rs index 2ae0c0e47c47..e63ad1193774 100644 --- a/crates/polars-plan/src/constants.rs +++ b/crates/polars-plan/src/constants.rs @@ -1,22 +1,22 @@ -use std::sync::{Arc, OnceLock}; +use std::sync::OnceLock; -use crate::prelude::ColumnName; +use polars_utils::pl_str::PlSmallStr; pub static MAP_LIST_NAME: &str = "map_list"; pub static CSE_REPLACED: &str = "__POLARS_CSER_"; pub const LEN: &str = "len"; -pub const LITERAL_NAME: &str = "literal"; +const LITERAL_NAME: &str = "literal"; pub const UNLIMITED_CACHE: u32 = u32::MAX; // Cache the often used LITERAL and LEN constants -static LITERAL_NAME_INIT: OnceLock> = OnceLock::new(); -static LEN_INIT: OnceLock> = OnceLock::new(); +static LITERAL_NAME_INIT: OnceLock = OnceLock::new(); +static LEN_INIT: OnceLock = OnceLock::new(); -pub(crate) fn get_literal_name() -> Arc { - LITERAL_NAME_INIT - .get_or_init(|| ColumnName::from(LITERAL_NAME)) - .clone() +pub fn get_literal_name() -> &'static PlSmallStr { + LITERAL_NAME_INIT.get_or_init(|| PlSmallStr::from_static(LITERAL_NAME)) } -pub(crate) fn get_len_name() -> Arc { - LEN_INIT.get_or_init(|| ColumnName::from(LEN)).clone() +pub(crate) fn get_len_name() -> PlSmallStr { + LEN_INIT + .get_or_init(|| PlSmallStr::from_static(LEN)) + .clone() } diff --git a/crates/polars-plan/src/dsl/arithmetic.rs b/crates/polars-plan/src/dsl/arithmetic.rs index c192476cee3b..d155a301f272 100644 --- a/crates/polars-plan/src/dsl/arithmetic.rs +++ b/crates/polars-plan/src/dsl/arithmetic.rs @@ -63,7 +63,7 @@ impl Expr { FunctionExpr::Pow(PowFunction::Generic), &[exponent.into()], false, - false, + None, ) } @@ -122,7 +122,7 @@ impl Expr { /// Compute the inverse tangent of the given expression, with the angle expressed as the argument of a complex number #[cfg(feature = "trigonometry")] pub fn arctan2(self, x: Self) -> Self { 
- self.map_many_private(FunctionExpr::Atan2, &[x], false, false) + self.map_many_private(FunctionExpr::Atan2, &[x], false, None) } /// Compute the hyperbolic cosine of the given expression diff --git a/crates/polars-plan/src/dsl/array.rs b/crates/polars-plan/src/dsl/array.rs index e781e65201fe..558a7a98a42a 100644 --- a/crates/polars-plan/src/dsl/array.rs +++ b/crates/polars-plan/src/dsl/array.rs @@ -110,7 +110,7 @@ impl ArrayNameSpace { FunctionExpr::ArrayExpr(ArrayFunction::Get(null_on_oob)), &[index], false, - false, + None, ) } @@ -122,7 +122,7 @@ impl ArrayNameSpace { FunctionExpr::ArrayExpr(ArrayFunction::Join(ignore_nulls)), &[separator], false, - false, + None, ) } @@ -135,7 +135,7 @@ impl ArrayNameSpace { FunctionExpr::ArrayExpr(ArrayFunction::Contains), &[other], false, - false, + None, ) } @@ -149,7 +149,7 @@ impl ArrayNameSpace { FunctionExpr::ArrayExpr(ArrayFunction::CountMatches), &[other], false, - false, + None, ) .with_function_options(|mut options| { options.flags |= FunctionFlags::INPUT_WILDCARD_EXPANSION; @@ -174,7 +174,7 @@ impl ArrayNameSpace { let fields = (0..*width) .map(|i| { let name = arr_default_struct_name_gen(i); - Field::from_owned(name, inner.as_ref().clone()) + Field::new(name, inner.as_ref().clone()) }) .collect(); Ok(DataType::Struct(fields)) @@ -189,7 +189,7 @@ impl ArrayNameSpace { FunctionExpr::ArrayExpr(ArrayFunction::Shift), &[n], false, - false, + None, ) } } diff --git a/crates/polars-plan/src/dsl/binary.rs b/crates/polars-plan/src/dsl/binary.rs index cf47edff38d1..9091b1777b65 100644 --- a/crates/polars-plan/src/dsl/binary.rs +++ b/crates/polars-plan/src/dsl/binary.rs @@ -9,7 +9,7 @@ impl BinaryNameSpace { FunctionExpr::BinaryExpr(BinaryFunction::Contains), &[pat], false, - true, + Some(Default::default()), ) } @@ -19,7 +19,7 @@ impl BinaryNameSpace { FunctionExpr::BinaryExpr(BinaryFunction::EndsWith), &[sub], false, - true, + Some(Default::default()), ) } @@ -29,7 +29,7 @@ impl BinaryNameSpace { FunctionExpr::BinaryExpr(BinaryFunction::StartsWith), &[sub], false, - true, + Some(Default::default()), ) } diff --git a/crates/polars-plan/src/dsl/dt.rs b/crates/polars-plan/src/dsl/dt.rs index 333f8e8e5680..141cff847207 100644 --- a/crates/polars-plan/src/dsl/dt.rs +++ b/crates/polars-plan/src/dsl/dt.rs @@ -21,7 +21,7 @@ impl DateLikeNameSpace { }), &[n], false, - false, + None, ) } @@ -208,7 +208,7 @@ impl DateLikeNameSpace { FunctionExpr::TemporalExpr(TemporalFunction::Truncate), &[every], false, - false, + None, ) } @@ -246,7 +246,7 @@ impl DateLikeNameSpace { FunctionExpr::TemporalExpr(TemporalFunction::Round), &[every], false, - false, + None, ) } @@ -258,7 +258,7 @@ impl DateLikeNameSpace { FunctionExpr::TemporalExpr(TemporalFunction::OffsetBy), &[by], false, - false, + None, ) } @@ -273,7 +273,7 @@ impl DateLikeNameSpace { FunctionExpr::TemporalExpr(TemporalFunction::ReplaceTimeZone(time_zone, non_existent)), &[ambiguous], false, - false, + None, ) } @@ -283,7 +283,7 @@ impl DateLikeNameSpace { FunctionExpr::TemporalExpr(TemporalFunction::Combine(tu)), &[time], false, - false, + None, ) } diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs index 37254c4da3aa..ced2de5e7eb5 100644 --- a/crates/polars-plan/src/dsl/expr.rs +++ b/crates/polars-plan/src/dsl/expr.rs @@ -61,16 +61,18 @@ impl AsRef for AggExpr { } } -/// Expressions that can be used in various contexts. Queries consist of multiple expressions. 
When using the polars -/// lazy API, don't construct an `Expr` directly; instead, create one using the functions in the `polars_lazy::dsl` -/// module. See that module's docs for more info. +/// Expressions that can be used in various contexts. +/// +/// Queries consist of multiple expressions. +/// When using the polars lazy API, don't construct an `Expr` directly; instead, create one using +/// the functions in the `polars_lazy::dsl` module. See that module's docs for more info. #[derive(Clone, PartialEq)] #[must_use] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum Expr { - Alias(Arc, ColumnName), - Column(ColumnName), - Columns(Arc<[ColumnName]>), + Alias(Arc, PlSmallStr), + Column(PlSmallStr), + Columns(Arc<[PlSmallStr]>), DtypeColumn(Vec), IndexColumn(Arc<[i64]>), Literal(LiteralValue), @@ -81,7 +83,7 @@ pub enum Expr { }, Cast { expr: Arc, - data_type: DataType, + dtype: DataType, options: CastOptions, }, Sort { @@ -134,27 +136,25 @@ pub enum Expr { length: Arc, }, /// Can be used in a select statement to exclude a column from selection + /// TODO: See if we can replace `Vec` with `Arc` Exclude(Arc, Vec), /// Set root name as Alias KeepName(Arc), Len, /// Take the nth column in the `DataFrame` Nth(i64), - // skipped fields must be last otherwise serde fails in pickle - #[cfg_attr(feature = "serde", serde(skip))] RenameAlias { function: SpecialEq>, expr: Arc, }, #[cfg(feature = "dtype-struct")] - Field(Arc<[ColumnName]>), + Field(Arc<[PlSmallStr]>), AnonymousFunction { /// function arguments input: Vec, /// function to apply function: SpecialEq>, /// output dtype of the function - #[cfg_attr(feature = "serde", serde(skip))] output_type: GetOutput, options: FunctionOptions, }, @@ -192,11 +192,11 @@ impl Hash for Expr { }, Expr::Cast { expr, - data_type, + dtype, options: strict, } => { expr.hash(state); - data_type.hash(state); + dtype.hash(state); strict.hash(state) }, Expr::Sort { expr, options } => { @@ -301,7 +301,7 @@ impl Default for Expr { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum Excluded { - Name(ColumnName), + Name(PlSmallStr), Dtype(DataType), } @@ -318,7 +318,7 @@ impl Expr { ctxt: Context, expr_arena: &mut Arena, ) -> PolarsResult { - let root = to_aexpr(self.clone(), expr_arena); + let root = to_aexpr(self.clone(), expr_arena)?; expr_arena.get(root).to_field(schema, ctxt, expr_arena) } } @@ -376,7 +376,7 @@ impl Display for Operator { } impl Operator { - pub(crate) fn is_comparison(&self) -> bool { + pub fn is_comparison(&self) -> bool { matches!( self, Self::Eq diff --git a/crates/polars-plan/src/dsl/expr_dyn_fn.rs b/crates/polars-plan/src/dsl/expr_dyn_fn.rs index d2593d0e3bcb..9ac6f872eed8 100644 --- a/crates/polars-plan/src/dsl/expr_dyn_fn.rs +++ b/crates/polars-plan/src/dsl/expr_dyn_fn.rs @@ -1,5 +1,6 @@ use std::fmt::Formatter; use std::ops::Deref; +use std::sync::Arc; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -17,7 +18,7 @@ pub trait SeriesUdf: Send + Sync { fn call_udf(&self, s: &mut [Series]) -> PolarsResult>; fn try_serialize(&self, _buf: &mut Vec) -> PolarsResult<()> { - polars_bail!(ComputeError: "serialize not supported for this 'opaque' function") + polars_bail!(ComputeError: "serialization not supported for this 'opaque' function") } // Needed for python functions. 
After they are deserialized we first check if they @@ -46,30 +47,29 @@ impl Serialize for SpecialEq> { #[cfg(feature = "serde")] impl<'a> Deserialize<'a> for SpecialEq> { - fn deserialize(_deserializer: D) -> std::result::Result + fn deserialize(deserializer: D) -> std::result::Result where D: Deserializer<'a>, { use serde::de::Error; #[cfg(feature = "python")] { - use crate::dsl::python_udf::MAGIC_BYTE_MARK; - let buf = Vec::::deserialize(_deserializer)?; + let buf = Vec::::deserialize(deserializer)?; - if buf.starts_with(MAGIC_BYTE_MARK) { + if buf.starts_with(python_udf::MAGIC_BYTE_MARK) { let udf = python_udf::PythonUdfExpression::try_deserialize(&buf) .map_err(|e| D::Error::custom(format!("{e}")))?; Ok(SpecialEq::new(udf)) } else { Err(D::Error::custom( - "deserialize not supported for this 'opaque' function", + "deserialization not supported for this 'opaque' function", )) } } #[cfg(not(feature = "python"))] { Err(D::Error::custom( - "deserialize not supported for this 'opaque' function", + "deserialization not supported for this 'opaque' function", )) } } @@ -124,11 +124,18 @@ impl Default for SpecialEq> { } pub trait RenameAliasFn: Send + Sync { - fn call(&self, name: &str) -> PolarsResult; + fn call(&self, name: &PlSmallStr) -> PolarsResult; + + fn try_serialize(&self, _buf: &mut Vec) -> PolarsResult<()> { + polars_bail!(ComputeError: "serialization not supported for this renaming function") + } } -impl PolarsResult + Send + Sync> RenameAliasFn for F { - fn call(&self, name: &str) -> PolarsResult { +impl RenameAliasFn for F +where + F: Fn(&PlSmallStr) -> PolarsResult + Send + Sync, +{ + fn call(&self, name: &PlSmallStr) -> PolarsResult { self(name) } } @@ -250,6 +257,10 @@ pub trait FunctionOutputField: Send + Sync { cntxt: Context, fields: &[Field], ) -> PolarsResult; + + fn try_serialize(&self, _buf: &mut Vec) -> PolarsResult<()> { + polars_bail!(ComputeError: "serialization not supported for this output field") + } } pub type GetOutput = SpecialEq>; @@ -269,7 +280,7 @@ impl GetOutput { pub fn from_type(dt: DataType) -> Self { SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| { - Ok(Field::new(flds[0].name(), dt.clone())) + Ok(Field::new(flds[0].name().clone(), dt.clone())) })) } @@ -292,7 +303,7 @@ impl GetOutput { ) -> Self { SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| { let mut fld = flds[0].clone(); - let new_type = f(fld.data_type())?; + let new_type = f(fld.dtype())?; fld.coerce(new_type); Ok(fld) })) @@ -323,7 +334,7 @@ impl GetOutput { { SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| { let mut fld = flds[0].clone(); - let dtypes = flds.iter().map(|fld| fld.data_type()).collect::>(); + let dtypes = flds.iter().map(|fld| fld.dtype()).collect::>(); let new_type = f(&dtypes)?; fld.coerce(new_type); Ok(fld) @@ -344,3 +355,76 @@ where self(input_schema, cntxt, fields) } } + +#[cfg(feature = "serde")] +impl Serialize for GetOutput { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: Serializer, + { + use serde::ser::Error; + let mut buf = vec![]; + self.0 + .try_serialize(&mut buf) + .map_err(|e| S::Error::custom(format!("{e}")))?; + serializer.serialize_bytes(&buf) + } +} + +#[cfg(feature = "serde")] +impl<'a> Deserialize<'a> for GetOutput { + fn deserialize(deserializer: D) -> std::result::Result + where + D: Deserializer<'a>, + { + use serde::de::Error; + #[cfg(feature = "python")] + { + let buf = Vec::::deserialize(deserializer)?; + + if buf.starts_with(python_udf::MAGIC_BYTE_MARK) { 
+ let get_output = python_udf::PythonGetOutput::try_deserialize(&buf) + .map_err(|e| D::Error::custom(format!("{e}")))?; + Ok(SpecialEq::new(get_output)) + } else { + Err(D::Error::custom( + "deserialization not supported for this output field", + )) + } + } + #[cfg(not(feature = "python"))] + { + Err(D::Error::custom( + "deserialization not supported for this output field", + )) + } + } +} + +#[cfg(feature = "serde")] +impl Serialize for SpecialEq> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: Serializer, + { + use serde::ser::Error; + let mut buf = vec![]; + self.0 + .try_serialize(&mut buf) + .map_err(|e| S::Error::custom(format!("{e}")))?; + serializer.serialize_bytes(&buf) + } +} + +#[cfg(feature = "serde")] +impl<'a> Deserialize<'a> for SpecialEq> { + fn deserialize(_deserializer: D) -> std::result::Result + where + D: Deserializer<'a>, + { + use serde::de::Error; + Err(D::Error::custom( + "deserialization not supported for this renaming function", + )) + } +} diff --git a/crates/polars-plan/src/dsl/from.rs b/crates/polars-plan/src/dsl/from.rs index eeaa631521cb..dcc53f51e1f9 100644 --- a/crates/polars-plan/src/dsl/from.rs +++ b/crates/polars-plan/src/dsl/from.rs @@ -8,7 +8,7 @@ impl From for Expr { impl From<&str> for Expr { fn from(s: &str) -> Self { - col(s) + col(PlSmallStr::from_str(s)) } } diff --git a/crates/polars-plan/src/dsl/function_expr/arg_where.rs b/crates/polars-plan/src/dsl/function_expr/arg_where.rs index 74bafa243e00..8f77be0724bd 100644 --- a/crates/polars-plan/src/dsl/function_expr/arg_where.rs +++ b/crates/polars-plan/src/dsl/function_expr/arg_where.rs @@ -6,7 +6,11 @@ pub(super) fn arg_where(s: &mut [Series]) -> PolarsResult> { let predicate = s[0].bool()?; if predicate.is_empty() { - Ok(Some(Series::full_null(predicate.name(), 0, &IDX_DTYPE))) + Ok(Some(Series::full_null( + predicate.name().clone(), + 0, + &IDX_DTYPE, + ))) } else { let capacity = predicate.sum().unwrap(); let mut out = Vec::with_capacity(capacity as usize); @@ -32,7 +36,7 @@ pub(super) fn arg_where(s: &mut [Series]) -> PolarsResult> { total_offset += arr.len(); }); - let ca = IdxCa::with_chunk(predicate.name(), IdxArr::from_vec(out)); + let ca = IdxCa::with_chunk(predicate.name().clone(), IdxArr::from_vec(out)); Ok(Some(ca.into_series())) } } diff --git a/crates/polars-plan/src/dsl/function_expr/array.rs b/crates/polars-plan/src/dsl/function_expr/array.rs index ece457aa7143..0de5e9d99883 100644 --- a/crates/polars-plan/src/dsl/function_expr/array.rs +++ b/crates/polars-plan/src/dsl/function_expr/array.rs @@ -221,7 +221,9 @@ pub(super) fn contains(s: &[Series]) -> PolarsResult { polars_ensure!(matches!(array.dtype(), DataType::Array(_, _)), SchemaMismatch: "invalid series dtype: expected `Array`, got `{}`", array.dtype(), ); - Ok(is_in(item, array)?.with_name(array.name()).into_series()) + Ok(is_in(item, array)? 
+ .with_name(array.name().clone()) + .into_series()) } #[cfg(feature = "array_count")] diff --git a/crates/polars-plan/src/dsl/function_expr/binary.rs b/crates/polars-plan/src/dsl/function_expr/binary.rs index 3a2525fec060..f803ba0ba952 100644 --- a/crates/polars-plan/src/dsl/function_expr/binary.rs +++ b/crates/polars-plan/src/dsl/function_expr/binary.rs @@ -86,7 +86,10 @@ impl From for SpecialEq> { pub(super) fn contains(s: &[Series]) -> PolarsResult { let ca = s[0].binary()?; let lit = s[1].binary()?; - Ok(ca.contains_chunked(lit).with_name(ca.name()).into_series()) + Ok(ca + .contains_chunked(lit) + .with_name(ca.name().clone()) + .into_series()) } pub(super) fn ends_with(s: &[Series]) -> PolarsResult { @@ -95,7 +98,7 @@ pub(super) fn ends_with(s: &[Series]) -> PolarsResult { Ok(ca .ends_with_chunked(suffix) - .with_name(ca.name()) + .with_name(ca.name().clone()) .into_series()) } @@ -105,7 +108,7 @@ pub(super) fn starts_with(s: &[Series]) -> PolarsResult { Ok(ca .starts_with_chunked(prefix) - .with_name(ca.name()) + .with_name(ca.name().clone()) .into_series()) } diff --git a/crates/polars-plan/src/dsl/function_expr/boolean.rs b/crates/polars-plan/src/dsl/function_expr/boolean.rs index d77da88f69a7..d00045c0d3f9 100644 --- a/crates/polars-plan/src/dsl/function_expr/boolean.rs +++ b/crates/polars-plan/src/dsl/function_expr/boolean.rs @@ -133,18 +133,18 @@ impl From for FunctionExpr { fn any(s: &Series, ignore_nulls: bool) -> PolarsResult { let ca = s.bool()?; if ignore_nulls { - Ok(Series::new(s.name(), [ca.any()])) + Ok(Series::new(s.name().clone(), [ca.any()])) } else { - Ok(Series::new(s.name(), [ca.any_kleene()])) + Ok(Series::new(s.name().clone(), [ca.any_kleene()])) } } fn all(s: &Series, ignore_nulls: bool) -> PolarsResult { let ca = s.bool()?; if ignore_nulls { - Ok(Series::new(s.name(), [ca.all()])) + Ok(Series::new(s.name().clone(), [ca.all()])) } else { - Ok(Series::new(s.name(), [ca.all_kleene()])) + Ok(Series::new(s.name().clone(), [ca.all_kleene()])) } } @@ -217,16 +217,19 @@ fn any_horizontal(s: &[Series]) -> PolarsResult { .install(|| { s.par_iter() .try_fold( - || BooleanChunked::new("", &[false]), + || BooleanChunked::new(PlSmallStr::EMPTY, &[false]), |acc, b| { let b = b.cast(&DataType::Boolean)?; let b = b.bool()?; PolarsResult::Ok((&acc).bitor(b)) }, ) - .try_reduce(|| BooleanChunked::new("", [false]), |a, b| Ok(a.bitor(b))) + .try_reduce( + || BooleanChunked::new(PlSmallStr::EMPTY, [false]), + |a, b| Ok(a.bitor(b)), + ) })? - .with_name(s[0].name()); + .with_name(s[0].name().clone()); Ok(out.into_series()) } @@ -236,15 +239,18 @@ fn all_horizontal(s: &[Series]) -> PolarsResult { .install(|| { s.par_iter() .try_fold( - || BooleanChunked::new("", &[true]), + || BooleanChunked::new(PlSmallStr::EMPTY, &[true]), |acc, b| { let b = b.cast(&DataType::Boolean)?; let b = b.bool()?; PolarsResult::Ok((&acc).bitand(b)) }, ) - .try_reduce(|| BooleanChunked::new("", [true]), |a, b| Ok(a.bitand(b))) + .try_reduce( + || BooleanChunked::new(PlSmallStr::EMPTY, [true]), + |a, b| Ok(a.bitand(b)), + ) })? 
- .with_name(s[0].name()); + .with_name(s[0].name().clone()); Ok(out.into_series()) } diff --git a/crates/polars-plan/src/dsl/function_expr/bounds.rs b/crates/polars-plan/src/dsl/function_expr/bounds.rs index 7dcce34e2a71..0f14feb5675f 100644 --- a/crates/polars-plan/src/dsl/function_expr/bounds.rs +++ b/crates/polars-plan/src/dsl/function_expr/bounds.rs @@ -1,7 +1,7 @@ use super::*; pub(super) fn upper_bound(s: &Series) -> PolarsResult { - let name = s.name(); + let name = s.name().clone(); use DataType::*; let s = match s.dtype().to_physical() { #[cfg(feature = "dtype-i8")] @@ -26,7 +26,7 @@ pub(super) fn upper_bound(s: &Series) -> PolarsResult { } pub(super) fn lower_bound(s: &Series) -> PolarsResult { - let name = s.name(); + let name = s.name().clone(); use DataType::*; let s = match s.dtype().to_physical() { #[cfg(feature = "dtype-i8")] diff --git a/crates/polars-plan/src/dsl/function_expr/cat.rs b/crates/polars-plan/src/dsl/function_expr/cat.rs index db50f4ef4429..9cc5d993a638 100644 --- a/crates/polars-plan/src/dsl/function_expr/cat.rs +++ b/crates/polars-plan/src/dsl/function_expr/cat.rs @@ -46,5 +46,5 @@ fn get_categories(s: &Series) -> PolarsResult { let ca = s.categorical()?; let rev_map = ca.get_rev_map(); let arr = rev_map.get_categories().clone().boxed(); - Series::try_from((ca.name(), arr)) + Series::try_from((ca.name().clone(), arr)) } diff --git a/crates/polars-plan/src/dsl/function_expr/coerce.rs b/crates/polars-plan/src/dsl/function_expr/coerce.rs index b131229b5f44..652866491edb 100644 --- a/crates/polars-plan/src/dsl/function_expr/coerce.rs +++ b/crates/polars-plan/src/dsl/function_expr/coerce.rs @@ -1,5 +1,5 @@ use polars_core::prelude::*; pub fn as_struct(s: &[Series]) -> PolarsResult { - Ok(StructChunked::from_series(s[0].name(), s)?.into_series()) + Ok(StructChunked::from_series(s[0].name().clone(), s)?.into_series()) } diff --git a/crates/polars-plan/src/dsl/function_expr/correlation.rs b/crates/polars-plan/src/dsl/function_expr/correlation.rs index 1510d5145fc1..216a635ba475 100644 --- a/crates/polars-plan/src/dsl/function_expr/correlation.rs +++ b/crates/polars-plan/src/dsl/function_expr/correlation.rs @@ -39,7 +39,7 @@ pub(super) fn corr(s: &[Series], ddof: u8, method: CorrelationMethod) -> PolarsR fn covariance(s: &[Series], ddof: u8) -> PolarsResult { let a = &s[0]; let b = &s[1]; - let name = "cov"; + let name = PlSmallStr::from_static("cov"); use polars_ops::chunked_array::cov::cov; let ret = match a.dtype() { @@ -64,13 +64,13 @@ fn covariance(s: &[Series], ddof: u8) -> PolarsResult { fn pearson_corr(s: &[Series], ddof: u8) -> PolarsResult { let a = &s[0]; let b = &s[1]; - let name = "pearson_corr"; + let name = PlSmallStr::from_static("pearson_corr"); use polars_ops::chunked_array::cov::pearson_corr; let ret = match a.dtype() { DataType::Float32 => { let ret = pearson_corr(a.f32().unwrap(), b.f32().unwrap(), ddof).map(|v| v as f32); - return Ok(Series::new(name, &[ret])); + return Ok(Series::new(name.clone(), &[ret])); }, DataType::Float64 => pearson_corr(a.f64().unwrap(), b.f64().unwrap(), ddof), DataType::Int32 => pearson_corr(a.i32().unwrap(), b.i32().unwrap(), ddof), @@ -94,10 +94,10 @@ fn spearman_rank_corr(s: &[Series], ddof: u8, propagate_nans: bool) -> PolarsRes let (a, b) = coalesce_nulls_series(a, b); - let name = "spearman_rank_correlation"; + let name = PlSmallStr::from_static("spearman_rank_correlation"); if propagate_nans && a.dtype().is_float() { for s in [&a, &b] { - if nan_max_s(s, "") + if nan_max_s(s, PlSmallStr::EMPTY) .get(0) 
.unwrap() .extract::() diff --git a/crates/polars-plan/src/dsl/function_expr/datetime.rs b/crates/polars-plan/src/dsl/function_expr/datetime.rs index 1a6251a43f82..1d1d6a5022e4 100644 --- a/crates/polars-plan/src/dsl/function_expr/datetime.rs +++ b/crates/polars-plan/src/dsl/function_expr/datetime.rs @@ -1,4 +1,3 @@ -use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS, SECONDS_IN_DAY}; #[cfg(feature = "timezones")] use chrono_tz::Tz; #[cfg(feature = "timezones")] @@ -124,7 +123,7 @@ impl TemporalFunction { time_unit, time_zone, } => Ok(Field::new( - "datetime", + PlSmallStr::from_static("datetime"), DataType::Datetime(*time_unit, time_zone.clone()), )), Combine(tu) => mapper.try_map_dtype(|dt| match dt { @@ -312,24 +311,31 @@ pub(super) fn microsecond(s: &Series) -> PolarsResult { pub(super) fn nanosecond(s: &Series) -> PolarsResult { s.nanosecond().map(|ca| ca.into_series()) } +#[cfg(feature = "dtype-duration")] pub(super) fn total_days(s: &Series) -> PolarsResult { s.duration().map(|ca| ca.days().into_series()) } +#[cfg(feature = "dtype-duration")] pub(super) fn total_hours(s: &Series) -> PolarsResult { s.duration().map(|ca| ca.hours().into_series()) } +#[cfg(feature = "dtype-duration")] pub(super) fn total_minutes(s: &Series) -> PolarsResult { s.duration().map(|ca| ca.minutes().into_series()) } +#[cfg(feature = "dtype-duration")] pub(super) fn total_seconds(s: &Series) -> PolarsResult { s.duration().map(|ca| ca.seconds().into_series()) } +#[cfg(feature = "dtype-duration")] pub(super) fn total_milliseconds(s: &Series) -> PolarsResult { s.duration().map(|ca| ca.milliseconds().into_series()) } +#[cfg(feature = "dtype-duration")] pub(super) fn total_microseconds(s: &Series) -> PolarsResult { s.duration().map(|ca| ca.microseconds().into_series()) } +#[cfg(feature = "dtype-duration")] pub(super) fn total_nanoseconds(s: &Series) -> PolarsResult { s.duration().map(|ca| ca.nanoseconds().into_series()) } @@ -501,90 +507,3 @@ pub(super) fn round(s: &[Series]) -> PolarsResult { dt => polars_bail!(opq = round, got = dt, expected = "date/datetime"), }) } - -pub(super) fn duration(s: &[Series], time_unit: TimeUnit) -> PolarsResult { - if s.iter().any(|s| s.is_empty()) { - return Ok(Series::new_empty( - s[0].name(), - &DataType::Duration(time_unit), - )); - } - - // TODO: Handle overflow for UInt64 - let weeks = s[0].cast(&DataType::Int64).unwrap(); - let days = s[1].cast(&DataType::Int64).unwrap(); - let hours = s[2].cast(&DataType::Int64).unwrap(); - let minutes = s[3].cast(&DataType::Int64).unwrap(); - let seconds = s[4].cast(&DataType::Int64).unwrap(); - let mut milliseconds = s[5].cast(&DataType::Int64).unwrap(); - let mut microseconds = s[6].cast(&DataType::Int64).unwrap(); - let mut nanoseconds = s[7].cast(&DataType::Int64).unwrap(); - - let is_scalar = |s: &Series| s.len() == 1; - let is_zero_scalar = |s: &Series| is_scalar(s) && s.get(0).unwrap() == AnyValue::Int64(0); - - // Process subseconds - let max_len = s.iter().map(|s| s.len()).max().unwrap(); - let mut duration = match time_unit { - TimeUnit::Microseconds => { - if is_scalar(&microseconds) { - microseconds = microseconds.new_from_index(0, max_len); - } - if !is_zero_scalar(&nanoseconds) { - microseconds = (microseconds + (nanoseconds.wrapping_trunc_div_scalar(1_000)))?; - } - if !is_zero_scalar(&milliseconds) { - microseconds = (microseconds + (milliseconds * 1_000))?; - } - microseconds - }, - TimeUnit::Nanoseconds => { - if is_scalar(&nanoseconds) { - nanoseconds = nanoseconds.new_from_index(0, max_len); - } - if
!is_zero_scalar(&microseconds) { - nanoseconds = (nanoseconds + (microseconds * 1_000))?; - } - if !is_zero_scalar(&milliseconds) { - nanoseconds = (nanoseconds + (milliseconds * 1_000_000))?; - } - nanoseconds - }, - TimeUnit::Milliseconds => { - if is_scalar(&milliseconds) { - milliseconds = milliseconds.new_from_index(0, max_len); - } - if !is_zero_scalar(&nanoseconds) { - milliseconds = (milliseconds + (nanoseconds.wrapping_trunc_div_scalar(1_000_000)))?; - } - if !is_zero_scalar(&microseconds) { - milliseconds = (milliseconds + (microseconds.wrapping_trunc_div_scalar(1_000)))?; - } - milliseconds - }, - }; - - // Process other duration specifiers - let multiplier = match time_unit { - TimeUnit::Nanoseconds => NANOSECONDS, - TimeUnit::Microseconds => MICROSECONDS, - TimeUnit::Milliseconds => MILLISECONDS, - }; - if !is_zero_scalar(&seconds) { - duration = (duration + seconds * multiplier)?; - } - if !is_zero_scalar(&minutes) { - duration = (duration + minutes * (multiplier * 60))?; - } - if !is_zero_scalar(&hours) { - duration = (duration + hours * (multiplier * 60 * 60))?; - } - if !is_zero_scalar(&days) { - duration = (duration + days * (multiplier * SECONDS_IN_DAY))?; - } - if !is_zero_scalar(&weeks) { - duration = (duration + weeks * (multiplier * SECONDS_IN_DAY * 7))?; - } - - duration.cast(&DataType::Duration(time_unit)) -} diff --git a/crates/polars-plan/src/dsl/function_expr/dispatch.rs b/crates/polars-plan/src/dsl/function_expr/dispatch.rs index cd82ae4251d8..12275fc57200 100644 --- a/crates/polars-plan/src/dsl/function_expr/dispatch.rs +++ b/crates/polars-plan/src/dsl/function_expr/dispatch.rs @@ -58,11 +58,11 @@ pub(super) fn value_counts( s: &Series, sort: bool, parallel: bool, - name: String, + name: PlSmallStr, normalize: bool, ) -> PolarsResult { s.value_counts(sort, parallel, name, normalize) - .map(|df| df.into_struct(s.name()).into_series()) + .map(|df| df.into_struct(s.name().clone()).into_series()) } #[cfg(feature = "unique_counts")] @@ -121,13 +121,14 @@ pub(super) fn mode(s: &Series) -> PolarsResult { #[cfg(feature = "moment")] pub(super) fn skew(s: &Series, bias: bool) -> PolarsResult { - s.skew(bias).map(|opt_v| Series::new(s.name(), &[opt_v])) + s.skew(bias) + .map(|opt_v| Series::new(s.name().clone(), &[opt_v])) } #[cfg(feature = "moment")] pub(super) fn kurtosis(s: &Series, fisher: bool, bias: bool) -> PolarsResult { s.kurtosis(fisher, bias) - .map(|opt_v| Series::new(s.name(), &[opt_v])) + .map(|opt_v| Series::new(s.name().clone(), &[opt_v])) } pub(super) fn arg_unique(s: &Series) -> PolarsResult { diff --git a/crates/polars-plan/src/dsl/function_expr/fill_null.rs b/crates/polars-plan/src/dsl/function_expr/fill_null.rs index d5e408c0082d..f4d89f203226 100644 --- a/crates/polars-plan/src/dsl/function_expr/fill_null.rs +++ b/crates/polars-plan/src/dsl/function_expr/fill_null.rs @@ -28,7 +28,7 @@ pub(super) fn fill_null(s: &[Series]) -> PolarsResult { let cats = series.to_physical_repr(); let mask = cats.is_not_null(); let out = cats - .zip_with_same_type(&mask, &Series::new("", &[idx])) + .zip_with_same_type(&mask, &Series::new(PlSmallStr::EMPTY, &[idx])) .unwrap(); unsafe { return out.cast_unchecked(series.dtype()) } } diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index e68b080d17f1..05df577ed8f3 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -247,7 +247,7 @@ pub(super) fn contains(args: &mut [Series]) -> PolarsResult> {
SchemaMismatch: "invalid series dtype: expected `List`, got `{}`", list.dtype(), ); polars_ops::prelude::is_in(item, list).map(|mut ca| { - ca.rename(list.name()); + ca.rename(list.name().clone()); Some(ca.into_series()) }) } @@ -378,7 +378,7 @@ pub(super) fn slice(args: &mut [Series]) -> PolarsResult> { .collect_trusted() }, }; - out.rename(s.name()); + out.rename(s.name().clone()); Ok(Some(out.into_series())) } @@ -417,7 +417,7 @@ pub(super) fn get(s: &mut [Series], null_on_oob: bool) -> PolarsResult PolarsResult>()? }; - let s = Series::try_from((ca.name(), arr.values().clone())).unwrap(); + let s = Series::try_from((ca.name().clone(), arr.values().clone())).unwrap(); unsafe { s.take_unchecked(&take_by) } .cast(ca.inner_dtype()) .map(Some) @@ -599,13 +599,13 @@ pub(super) fn set_operation(s: &[Series], set_type: SetOperation) -> PolarsResul if s0.len() == 0 { Ok(s0.clone()) } else { - Ok(s1.clone().with_name(s0.name())) + Ok(s1.clone().with_name(s0.name().clone())) } }, SetOperation::Difference => Ok(s0.clone()), SetOperation::Union | SetOperation::SymmetricDifference => { if s0.len() == 0 { - Ok(s1.clone().with_name(s0.name())) + Ok(s1.clone().with_name(s0.name().clone())) } else { Ok(s0.clone()) } diff --git a/crates/polars-plan/src/dsl/function_expr/log.rs b/crates/polars-plan/src/dsl/function_expr/log.rs index 8793f9614a77..42c71c681f33 100644 --- a/crates/polars-plan/src/dsl/function_expr/log.rs +++ b/crates/polars-plan/src/dsl/function_expr/log.rs @@ -4,9 +4,9 @@ pub(super) fn entropy(s: &Series, base: f64, normalize: bool) -> PolarsResult, - labels: Option>, + labels: Option>, left_closed: bool, include_breaks: bool, }, #[cfg(feature = "cutqcut")] QCut { probs: Vec, - labels: Option>, + labels: Option>, left_closed: bool, allow_duplicates: bool, include_breaks: bool, @@ -307,9 +307,9 @@ pub enum FunctionExpr { /// This will lead to calls over FFI. FfiPlugin { /// Shared library. - lib: Arc, + lib: PlSmallStr, /// Identifier in the shared lib. - symbol: Arc, + symbol: PlSmallStr, /// Pickle serialized keyword arguments. 
kwargs: Arc<[u8]>, }, @@ -879,7 +879,10 @@ impl From for SpecialEq> { NullCount => { let f = |s: &mut [Series]| { let s = &s[0]; - Ok(Some(Series::new(s.name(), [s.null_count() as IdxSize]))) + Ok(Some(Series::new( + s.name().clone(), + [s.null_count() as IdxSize], + ))) }; wrap!(f) }, diff --git a/crates/polars-plan/src/dsl/function_expr/pow.rs b/crates/polars-plan/src/dsl/function_expr/pow.rs index a9bacae5ae84..5336220d1ace 100644 --- a/crates/polars-plan/src/dsl/function_expr/pow.rs +++ b/crates/polars-plan/src/dsl/function_expr/pow.rs @@ -37,12 +37,15 @@ where ChunkedArray: IntoSeries, { if (base.len() == 1) && (exponent.len() != 1) { + let name = base.name(); let base = base .get(0) .ok_or_else(|| polars_err!(ComputeError: "base is null"))?; Ok(Some( - unary_elementwise_values(exponent, |exp| Pow::pow(base, exp)).into_series(), + unary_elementwise_values(exponent, |exp| Pow::pow(base, exp)) + .into_series() + .with_name(name.clone()), )) } else { Ok(Some( @@ -65,7 +68,11 @@ where if exponent.len() == 1 { let Some(exponent_value) = exponent.get(0) else { - return Ok(Some(Series::full_null(base.name(), base.len(), &dtype))); + return Ok(Some(Series::full_null( + base.name().clone(), + base.len(), + &dtype, + ))); }; let s = match exponent_value.to_f64().unwrap() { a if a == 1.0 => base.clone().into_series(), @@ -104,7 +111,11 @@ where if exponent.len() == 1 { let Some(exponent_value) = exponent.get(0) else { - return Ok(Some(Series::full_null(base.name(), base.len(), &dtype))); + return Ok(Some(Series::full_null( + base.name().clone(), + base.len(), + &dtype, + ))); }; let s = match exponent_value.to_u64().unwrap() { 1 => base.clone().into_series(), diff --git a/crates/polars-plan/src/dsl/function_expr/random.rs b/crates/polars-plan/src/dsl/function_expr/random.rs index 1719e42a2feb..cb21e08367aa 100644 --- a/crates/polars-plan/src/dsl/function_expr/random.rs +++ b/crates/polars-plan/src/dsl/function_expr/random.rs @@ -46,7 +46,7 @@ pub(super) fn sample_frac( match frac.get(0) { Some(frac) => src.sample_frac(frac, with_replacement, shuffle, seed), - None => Ok(Series::new_empty(src.name(), src.dtype())), + None => Ok(Series::new_empty(src.name().clone(), src.dtype())), } } @@ -69,6 +69,6 @@ pub(super) fn sample_n( match n.get(0) { Some(n) => src.sample_n(n as usize, with_replacement, shuffle, seed), - None => Ok(Series::new_empty(src.name(), src.dtype())), + None => Ok(Series::new_empty(src.name().clone(), src.dtype())), } } diff --git a/crates/polars-plan/src/dsl/function_expr/range/date_range.rs b/crates/polars-plan/src/dsl/function_expr/range/date_range.rs index bef4946e5729..5518d32df275 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/date_range.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/date_range.rs @@ -25,7 +25,7 @@ pub(super) fn date_range( ComputeError: "`interval` input for `date_range` must consist of full days, got: {interval}" ); - let name = start.name(); + let name = start.name().clone(); let start = temporal_series_to_i64_scalar(&start) .ok_or_else(|| polars_err!(ComputeError: "start is an out-of-range time."))? 
* MILLISECONDS_IN_DAY; @@ -67,7 +67,7 @@ pub(super) fn date_ranges( let end = end.i64().unwrap() * MILLISECONDS_IN_DAY; let mut builder = ListPrimitiveChunkedBuilder::::new( - start.name(), + start.name().clone(), start.len(), start.len() * CAPACITY_FACTOR, DataType::Int32, @@ -75,7 +75,7 @@ pub(super) fn date_ranges( let range_impl = |start, end, builder: &mut ListPrimitiveChunkedBuilder| { let rng = datetime_range_impl( - "", + PlSmallStr::EMPTY, start, end, interval, diff --git a/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs b/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs index e046b94b03a7..394889dd34f1 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs @@ -95,7 +95,7 @@ pub(super) fn datetime_range( Some(tz) => Some(parse_time_zone(tz)?), _ => None, }; - datetime_range_impl(name, start, end, interval, closed, tu, tz.as_ref())? + datetime_range_impl(name.clone(), start, end, interval, closed, tu, tz.as_ref())? }, _ => unimplemented!(), }; @@ -189,7 +189,7 @@ pub(super) fn datetime_ranges( let out = match dtype { DataType::Datetime(tu, ref tz) => { let mut builder = ListPrimitiveChunkedBuilder::::new( - start.name(), + start.name().clone(), start.len(), start.len() * CAPACITY_FACTOR, DataType::Int64, @@ -201,7 +201,15 @@ pub(super) fn datetime_ranges( _ => None, }; let range_impl = |start, end, builder: &mut ListPrimitiveChunkedBuilder| { - let rng = datetime_range_impl("", start, end, interval, closed, tu, tz.as_ref())?; + let rng = datetime_range_impl( + PlSmallStr::EMPTY, + start, + end, + interval, + closed, + tu, + tz.as_ref(), + )?; builder.append_slice(rng.cont_slice().unwrap()); Ok(()) }; @@ -219,7 +227,7 @@ impl<'a> FieldsMapper<'a> { pub(super) fn map_to_datetime_range_dtype( &self, time_unit: Option<&TimeUnit>, - time_zone: Option<&str>, + time_zone: Option<&PlSmallStr>, ) -> PolarsResult { let data_dtype = self.map_to_supertype()?.dtype; @@ -233,10 +241,7 @@ impl<'a> FieldsMapper<'a> { Some(tu) => *tu, None => data_tu, }; - let tz = match time_zone { - Some(tz) => Some(tz.to_string()), - None => data_tz, - }; + let tz = time_zone.cloned().or(data_tz); Ok(DataType::Datetime(tu, tz)) } diff --git a/crates/polars-plan/src/dsl/function_expr/range/int_range.rs b/crates/polars-plan/src/dsl/function_expr/range/int_range.rs index 5344ec0b5ee8..f1ae0ffe13a7 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/int_range.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/int_range.rs @@ -27,7 +27,7 @@ pub(super) fn int_range(s: &[Series], step: i64, dtype: DataType) -> PolarsResul with_match_physical_integer_polars_type!(dtype, |$T| { let start_v = get_first_series_value::<$T>(start)?; let end_v = get_first_series_value::<$T>(end)?; - new_int_range::<$T>(start_v, end_v, step, name) + new_int_range::<$T>(start_v, end_v, step, name.clone()) }) } @@ -58,7 +58,7 @@ pub(super) fn int_ranges(s: &[Series]) -> PolarsResult { let len = std::cmp::max(start.len(), end.len()); let mut builder = ListPrimitiveChunkedBuilder::::new( // The name should follow our left hand rule. 
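// Editorial note, not part of the patch: the "left hand rule" referenced above means the output
// column takes its name from the left-most input (here `start`). A toy illustration of that
// naming convention, using a plain (name, values) pair instead of the Polars builder types:
fn add_keep_left_name(left: (&str, &[i64]), right: (&str, &[i64])) -> (String, Vec<i64>) {
    let values = left.1.iter().zip(right.1).map(|(a, b)| a + b).collect();
    (left.0.to_string(), values) // the right-hand name is dropped
}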
- start.name(), + start.name().clone(), len, len * CAPACITY_FACTOR, DataType::Int64, diff --git a/crates/polars-plan/src/dsl/function_expr/range/mod.rs b/crates/polars-plan/src/dsl/function_expr/range/mod.rs index b13d45bdd73c..3350f0c6f8f5 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/mod.rs @@ -83,7 +83,7 @@ impl RangeFunction { } => { // output dtype may change based on `interval`, `time_unit`, and `time_zone` let dtype = - mapper.map_to_datetime_range_dtype(time_unit.as_ref(), time_zone.as_deref())?; + mapper.map_to_datetime_range_dtype(time_unit.as_ref(), time_zone.as_ref())?; mapper.with_dtype(dtype) }, #[cfg(feature = "dtype-datetime")] @@ -95,7 +95,7 @@ impl RangeFunction { } => { // output dtype may change based on `interval`, `time_unit`, and `time_zone` let inner_dtype = - mapper.map_to_datetime_range_dtype(time_unit.as_ref(), time_zone.as_deref())?; + mapper.map_to_datetime_range_dtype(time_unit.as_ref(), time_zone.as_ref())?; mapper.with_dtype(DataType::List(Box::new(inner_dtype))) }, #[cfg(feature = "dtype-time")] diff --git a/crates/polars-plan/src/dsl/function_expr/range/time_range.rs b/crates/polars-plan/src/dsl/function_expr/range/time_range.rs index 991368356cc5..52211e89bc56 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/time_range.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/time_range.rs @@ -25,7 +25,7 @@ pub(super) fn time_range( let end = temporal_series_to_i64_scalar(&end.cast(&dtype)?) .ok_or_else(|| polars_err!(ComputeError: "end is an out-of-range time."))?; - let out = time_range_impl(name, start, end, interval, closed)?; + let out = time_range_impl(name.clone(), start, end, interval, closed)?; Ok(out.cast(&dtype).unwrap().into_series()) } @@ -47,14 +47,14 @@ pub(super) fn time_ranges( let len = std::cmp::max(start.len(), end.len()); let mut builder = ListPrimitiveChunkedBuilder::::new( - start.name(), + start.name().clone(), len, len * CAPACITY_FACTOR, DataType::Int64, ); let range_impl = |start, end, builder: &mut ListPrimitiveChunkedBuilder| { - let rng = time_range_impl("", start, end, interval, closed)?; + let rng = time_range_impl(PlSmallStr::EMPTY, start, end, interval, closed)?; builder.append_slice(rng.cont_slice().unwrap()); Ok(()) }; diff --git a/crates/polars-plan/src/dsl/function_expr/row_hash.rs b/crates/polars-plan/src/dsl/function_expr/row_hash.rs index 1f4b88885eea..3a2d33f08384 100644 --- a/crates/polars-plan/src/dsl/function_expr/row_hash.rs +++ b/crates/polars-plan/src/dsl/function_expr/row_hash.rs @@ -1,6 +1,6 @@ use super::*; pub(super) fn row_hash(s: &Series, k0: u64, k1: u64, k2: u64, k3: u64) -> PolarsResult { - Ok(s.hash(ahash::RandomState::with_seeds(k0, k1, k2, k3)) + Ok(s.hash(PlRandomState::with_seeds(k0, k1, k2, k3)) .into_series()) } diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index a385c27820d6..15f03e6bb848 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -95,7 +95,7 @@ impl FunctionExpr { }), #[cfg(feature = "dtype-struct")] AsStruct => Ok(Field::new( - fields[0].name(), + fields[0].name().clone(), DataType::Struct(fields.to_vec()), )), #[cfg(feature = "top_k")] @@ -115,8 +115,8 @@ impl FunctionExpr { IDX_DTYPE }; DataType::Struct(vec![ - Field::new(fields[0].name().as_str(), dt.clone()), - Field::new(name, count_dt), + Field::new(fields[0].name().clone(), dt.clone()), + 
Field::new(name.clone(), count_dt), ]) }), #[cfg(feature = "unique_counts")] @@ -143,15 +143,18 @@ impl FunctionExpr { if *include_breakpoint || *include_category { let mut fields = Vec::with_capacity(3); if *include_breakpoint { - fields.push(Field::new("breakpoint", DataType::Float64)); + fields.push(Field::new( + PlSmallStr::from_static("breakpoint"), + DataType::Float64, + )); } if *include_category { fields.push(Field::new( - "category", + PlSmallStr::from_static("category"), DataType::Categorical(None, Default::default()), )); } - fields.push(Field::new("count", IDX_DTYPE)); + fields.push(Field::new(PlSmallStr::from_static("count"), IDX_DTYPE)); mapper.with_dtype(DataType::Struct(fields)) } else { mapper.with_dtype(IDX_DTYPE) @@ -231,8 +234,11 @@ impl FunctionExpr { .. } => { let struct_dt = DataType::Struct(vec![ - Field::new("breakpoint", DataType::Float64), - Field::new("category", DataType::Categorical(None, Default::default())), + Field::new(PlSmallStr::from_static("breakpoint"), DataType::Float64), + Field::new( + PlSmallStr::from_static("category"), + DataType::Categorical(None, Default::default()), + ), ]); mapper.with_dtype(struct_dt) }, @@ -269,16 +275,19 @@ impl FunctionExpr { .. } => { let struct_dt = DataType::Struct(vec![ - Field::new("breakpoint", DataType::Float64), - Field::new("category", DataType::Categorical(None, Default::default())), + Field::new(PlSmallStr::from_static("breakpoint"), DataType::Float64), + Field::new( + PlSmallStr::from_static("category"), + DataType::Categorical(None, Default::default()), + ), ]); mapper.with_dtype(struct_dt) }, #[cfg(feature = "rle")] RLE => mapper.map_dtype(|dt| { DataType::Struct(vec![ - Field::new("len", IDX_DTYPE), - Field::new("value", dt.clone()), + Field::new(PlSmallStr::from_static("len"), IDX_DTYPE), + Field::new(PlSmallStr::from_static("value"), dt.clone()), ]) }), #[cfg(feature = "rle")] @@ -298,7 +307,7 @@ impl FunctionExpr { MaxHorizontal => mapper.map_to_supertype(), MinHorizontal => mapper.map_to_supertype(), SumHorizontal => { - if mapper.fields[0].data_type() == &DataType::Boolean { + if mapper.fields[0].dtype() == &DataType::Boolean { mapper.with_dtype(DataType::UInt32) } else { mapper.map_to_supertype() @@ -363,13 +372,13 @@ impl<'a> FieldsMapper<'a> { /// Set a dtype. pub fn with_dtype(&self, dtype: DataType) -> PolarsResult { - Ok(Field::new(self.fields[0].name(), dtype)) + Ok(Field::new(self.fields[0].name().clone(), dtype)) } /// Map a single dtype. pub fn map_dtype(&self, func: impl FnOnce(&DataType) -> DataType) -> PolarsResult { - let dtype = func(self.fields[0].data_type()); - Ok(Field::new(self.fields[0].name(), dtype)) + let dtype = func(self.fields[0].dtype()); + Ok(Field::new(self.fields[0].name().clone(), dtype)) } pub fn get_fields_lens(&self) -> usize { @@ -416,8 +425,8 @@ impl<'a> FieldsMapper<'a> { &self, func: impl FnOnce(&DataType) -> PolarsResult, ) -> PolarsResult { - let dtype = func(self.fields[0].data_type())?; - Ok(Field::new(self.fields[0].name(), dtype)) + let dtype = func(self.fields[0].dtype())?; + Ok(Field::new(self.fields[0].name().clone(), dtype)) } /// Map all dtypes with a potentially failing mapper function. 
@@ -429,7 +438,7 @@ impl<'a> FieldsMapper<'a> { let dtypes = self .fields .iter() - .map(|fld| fld.data_type()) + .map(|fld| fld.dtype()) .collect::>(); let new_type = func(&dtypes)?; fld.coerce(new_type); @@ -448,7 +457,7 @@ impl<'a> FieldsMapper<'a> { pub fn map_to_list_and_array_inner_dtype(&self) -> PolarsResult { let mut first = self.fields[0].clone(); let dt = first - .data_type() + .dtype() .inner_dtype() .cloned() .unwrap_or_else(|| DataType::Unknown(Default::default())); @@ -497,7 +506,7 @@ impl<'a> FieldsMapper<'a> { let mut first = self.fields[0].clone(); use DataType::*; let dt = first - .data_type() + .dtype() .inner_dtype() .cloned() .unwrap_or_else(|| Unknown(Default::default())); @@ -511,16 +520,25 @@ impl<'a> FieldsMapper<'a> { } pub(super) fn pow_dtype(&self) -> PolarsResult { - let base_dtype = self.fields[0].data_type(); - let exponent_dtype = self.fields[1].data_type(); + let base_dtype = self.fields[0].dtype(); + let exponent_dtype = self.fields[1].dtype(); if base_dtype.is_integer() { if exponent_dtype.is_float() { - Ok(Field::new(self.fields[0].name(), exponent_dtype.clone())) + Ok(Field::new( + self.fields[0].name().clone(), + exponent_dtype.clone(), + )) } else { - Ok(Field::new(self.fields[0].name(), base_dtype.clone())) + Ok(Field::new( + self.fields[0].name().clone(), + base_dtype.clone(), + )) } } else { - Ok(Field::new(self.fields[0].name(), base_dtype.clone())) + Ok(Field::new( + self.fields[0].name().clone(), + base_dtype.clone(), + )) } } @@ -538,8 +556,8 @@ impl<'a> FieldsMapper<'a> { let new = &self.fields[2]; let default = self.fields.get(3); match default { - Some(default) => try_get_supertype(default.data_type(), new.data_type())?, - None => new.data_type().clone(), + Some(default) => try_get_supertype(default.dtype(), new.dtype())?, + None => new.dtype().clone(), } }, }; diff --git a/crates/polars-plan/src/dsl/function_expr/shift_and_fill.rs b/crates/polars-plan/src/dsl/function_expr/shift_and_fill.rs index c2a0d16d78dc..6ebc5f3d221e 100644 --- a/crates/polars-plan/src/dsl/function_expr/shift_and_fill.rs +++ b/crates/polars-plan/src/dsl/function_expr/shift_and_fill.rs @@ -106,7 +106,7 @@ pub(super) fn shift_and_fill(args: &[Series]) -> PolarsResult { dt => polars_bail!(opq = shift_and_fill, dt), } } else { - Ok(Series::full_null(s.name(), s.len(), s.dtype())) + Ok(Series::full_null(s.name().clone(), s.len(), s.dtype())) } } @@ -123,6 +123,6 @@ pub fn shift(args: &[Series]) -> PolarsResult { match n.get(0) { Some(n) => Ok(s.shift(n)), - None => Ok(Series::full_null(s.name(), s.len(), s.dtype())), + None => Ok(Series::full_null(s.name().clone(), s.len(), s.dtype())), } } diff --git a/crates/polars-plan/src/dsl/function_expr/sign.rs b/crates/polars-plan/src/dsl/function_expr/sign.rs index 41707664e3ac..a7bf4d3277e6 100644 --- a/crates/polars-plan/src/dsl/function_expr/sign.rs +++ b/crates/polars-plan/src/dsl/function_expr/sign.rs @@ -1,41 +1,34 @@ +use num::{One, Zero}; use polars_core::export::num; -use DataType::*; +use polars_core::with_match_physical_numeric_polars_type; use super::*; pub(super) fn sign(s: &Series) -> PolarsResult { - match s.dtype() { - Float32 => { - let ca = s.f32().unwrap(); - sign_float(ca) - }, - Float64 => { - let ca = s.f64().unwrap(); - sign_float(ca) - }, - dt if dt.is_numeric() => { - let s = s.cast(&Float64)?; - sign(&s) - }, - dt => polars_bail!(opq = sign, dt), - } + let dt = s.dtype(); + polars_ensure!(dt.is_numeric(), opq = sign, dt); + with_match_physical_numeric_polars_type!(dt, |$T| { + let ca: &ChunkedArray<$T> = 
s.as_ref().as_ref(); + Ok(sign_impl(ca)) + }) } -fn sign_float(ca: &ChunkedArray) -> PolarsResult +fn sign_impl(ca: &ChunkedArray) -> Series where - T: PolarsFloatType, - T::Native: num::Float, + T: PolarsNumericType, ChunkedArray: IntoSeries, { - ca.apply_values(signum_improved).into_series().cast(&Int64) -} - -// Wrapper for the signum function that handles +/-0.0 inputs differently -// See discussion here: https://github.com/rust-lang/rust/issues/57543 -fn signum_improved(v: F) -> F { - if v.is_zero() { - v - } else { - v.signum() - } + ca.apply_values(|x| { + if x < T::Native::zero() { + T::Native::zero() - T::Native::one() + } else if x > T::Native::zero() { + T::Native::one() + } else { + // Returning x here ensures we return NaN for NaN input, and + // maintain the sign for signed zeroes (although we don't really + // care about the latter). + x + } + }) + .into_series() } diff --git a/crates/polars-plan/src/dsl/function_expr/strings.rs b/crates/polars-plan/src/dsl/function_expr/strings.rs index 77a9f9e519bb..9a5d2a9ff537 100644 --- a/crates/polars-plan/src/dsl/function_expr/strings.rs +++ b/crates/polars-plan/src/dsl/function_expr/strings.rs @@ -7,7 +7,7 @@ use once_cell::sync::Lazy; use polars_core::chunked_array::temporal::validate_time_zone; use polars_core::utils::handle_casting_failures; #[cfg(feature = "dtype-struct")] -use polars_utils::format_smartstring; +use polars_utils::format_pl_smallstr; #[cfg(feature = "regex")] use regex::{escape, Regex}; #[cfg(feature = "serde")] @@ -25,12 +25,12 @@ static TZ_AWARE_RE: Lazy = pub enum StringFunction { #[cfg(feature = "concat_str")] ConcatHorizontal { - delimiter: String, + delimiter: PlSmallStr, ignore_nulls: bool, }, #[cfg(feature = "concat_str")] ConcatVertical { - delimiter: String, + delimiter: PlSmallStr, ignore_nulls: bool, }, #[cfg(feature = "regex")] @@ -45,7 +45,7 @@ pub enum StringFunction { #[cfg(feature = "extract_groups")] ExtractGroups { dtype: DataType, - pat: String, + pat: PlSmallStr, }, #[cfg(feature = "regex")] Find { @@ -182,13 +182,13 @@ impl StringFunction { #[cfg(feature = "dtype-struct")] SplitExact { n, .. 
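// Editorial sketch, not part of the patch: the rewritten `sign` kernel above maps each value to
// -1, 0 or 1 in the column's own numeric dtype, and deliberately returns the input unchanged
// for zero (keeping signed zeroes) and for NaN. The same rule for a single f64, as a
// self-contained illustration:
fn sign_scalar(x: f64) -> f64 {
    if x < 0.0 {
        -1.0
    } else if x > 0.0 {
        1.0
    } else {
        x // 0.0, -0.0 and NaN all fall through unchanged
    }
}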
} => mapper.with_dtype(DataType::Struct( (0..n + 1) - .map(|i| Field::from_owned(format_smartstring!("field_{i}"), DataType::String)) + .map(|i| Field::new(format_pl_smallstr!("field_{i}"), DataType::String)) .collect(), )), #[cfg(feature = "dtype-struct")] SplitN(n) => mapper.with_dtype(DataType::Struct( (0..*n) - .map(|i| Field::from_owned(format_smartstring!("field_{i}"), DataType::String)) + .map(|i| Field::new(format_pl_smallstr!("field_{i}"), DataType::String)) .collect(), )), #[cfg(feature = "find_many")] @@ -576,7 +576,7 @@ pub(super) fn extract_all(args: &[Series]) -> PolarsResult { ca.extract_all(pat).map(|ca| ca.into_series()) } else { Ok(Series::full_null( - ca.name(), + ca.name().clone(), ca.len(), &DataType::List(Box::new(DataType::String)), )) @@ -596,7 +596,11 @@ pub(super) fn count_matches(args: &[Series], literal: bool) -> PolarsResult), - RenameFields(Arc<[String]>), - PrefixFields(Arc), - SuffixFields(Arc), + FieldByName(PlSmallStr), + RenameFields(Arc<[PlSmallStr]>), + PrefixFields(PlSmallStr), + SuffixFields(PlSmallStr), #[cfg(feature = "json")] JsonEncode, WithFields, - MultipleFields(Arc<[ColumnName]>), + MultipleFields(Arc<[PlSmallStr]>), } impl StructFunction { @@ -38,11 +39,11 @@ impl StructFunction { if let DataType::Struct(ref fields) = field.dtype { let fld = fields .iter() - .find(|fld| fld.name() == name.as_ref()) - .ok_or_else(|| polars_err!(StructFieldNotFound: "{}", name.as_ref()))?; + .find(|fld| fld.name() == name) + .ok_or_else(|| polars_err!(StructFieldNotFound: "{}", name))?; Ok(fld.clone()) } else { - polars_bail!(StructFieldNotFound: "{}", name.as_ref()); + polars_bail!(StructFieldNotFound: "{}", name); } }), RenameFields(names) => mapper.map_dtype(|dt| match dt { @@ -50,7 +51,7 @@ impl StructFunction { let fields = fields .iter() .zip(names.as_ref()) - .map(|(fld, name)| Field::new(name, fld.data_type().clone())) + .map(|(fld, name)| Field::new(name.clone(), fld.dtype().clone())) .collect(); DataType::Struct(fields) }, @@ -60,7 +61,7 @@ impl StructFunction { dt => DataType::Struct( names .iter() - .map(|name| Field::new(name, dt.clone())) + .map(|name| Field::new(name.clone(), dt.clone())) .collect(), ), }), @@ -70,7 +71,7 @@ impl StructFunction { .iter() .map(|fld| { let name = fld.name(); - Field::new(&format!("{prefix}{name}"), fld.data_type().clone()) + Field::new(format_pl_smallstr!("{prefix}{name}"), fld.dtype().clone()) }) .collect(); Ok(DataType::Struct(fields)) @@ -83,7 +84,7 @@ impl StructFunction { .iter() .map(|fld| { let name = fld.name(); - Field::new(&format!("{name}{suffix}"), fld.data_type().clone()) + Field::new(format_pl_smallstr!("{name}{suffix}"), fld.dtype().clone()) }) .collect(); Ok(DataType::Struct(fields)) @@ -96,26 +97,26 @@ impl StructFunction { let args = mapper.args(); let struct_ = &args[0]; - if let DataType::Struct(fields) = struct_.data_type() { + if let DataType::Struct(fields) = struct_.dtype() { let mut name_2_dtype = PlIndexMap::with_capacity(fields.len() * 2); for field in fields { - name_2_dtype.insert(field.name(), field.data_type()); + name_2_dtype.insert(field.name(), field.dtype()); } for arg in &args[1..] 
{ - name_2_dtype.insert(arg.name(), arg.data_type()); + name_2_dtype.insert(arg.name(), arg.dtype()); } let dtype = DataType::Struct( name_2_dtype .iter() - .map(|(name, dtype)| Field::new(name, (*dtype).clone())) + .map(|(&name, &dtype)| Field::new(name.clone(), dtype.clone())) .collect(), ); let mut out = struct_.clone(); out.coerce(dtype); Ok(out) } else { - let dt = struct_.data_type(); + let dt = struct_.dtype(); polars_bail!(op = "with_fields", got = dt, expected = "Struct") } }, @@ -146,10 +147,10 @@ impl From for SpecialEq> { use StructFunction::*; match func { FieldByIndex(_) => panic!("should be replaced"), - FieldByName(name) => map!(get_by_name, name.clone()), + FieldByName(name) => map!(get_by_name, &name), RenameFields(names) => map!(rename_fields, names.clone()), - PrefixFields(prefix) => map!(prefix_fields, prefix.clone()), - SuffixFields(suffix) => map!(suffix_fields, suffix.clone()), + PrefixFields(prefix) => map!(prefix_fields, prefix.as_str()), + SuffixFields(suffix) => map!(suffix_fields, suffix.as_str()), #[cfg(feature = "json")] JsonEncode => map!(to_json), WithFields => map_as_slice!(with_fields), @@ -158,12 +159,12 @@ impl From for SpecialEq> { } } -pub(super) fn get_by_name(s: &Series, name: Arc) -> PolarsResult { +pub(super) fn get_by_name(s: &Series, name: &str) -> PolarsResult { let ca = s.struct_()?; - ca.field_by_name(name.as_ref()) + ca.field_by_name(name) } -pub(super) fn rename_fields(s: &Series, names: Arc<[String]>) -> PolarsResult { +pub(super) fn rename_fields(s: &Series, names: Arc<[PlSmallStr]>) -> PolarsResult { let ca = s.struct_()?; let fields = ca .fields_as_series() @@ -171,16 +172,16 @@ pub(super) fn rename_fields(s: &Series, names: Arc<[String]>) -> PolarsResult>(); - let mut out = StructChunked::from_series(ca.name(), &fields)?; + let mut out = StructChunked::from_series(ca.name().clone(), &fields)?; out.zip_outer_validity(ca); Ok(out.into_series()) } -pub(super) fn prefix_fields(s: &Series, prefix: Arc) -> PolarsResult { +pub(super) fn prefix_fields(s: &Series, prefix: &str) -> PolarsResult { let ca = s.struct_()?; let fields = ca .fields_as_series() @@ -188,16 +189,16 @@ pub(super) fn prefix_fields(s: &Series, prefix: Arc) -> PolarsResult>(); - let mut out = StructChunked::from_series(ca.name(), &fields)?; + let mut out = StructChunked::from_series(ca.name().clone(), &fields)?; out.zip_outer_validity(ca); Ok(out.into_series()) } -pub(super) fn suffix_fields(s: &Series, suffix: Arc) -> PolarsResult { +pub(super) fn suffix_fields(s: &Series, suffix: &str) -> PolarsResult { let ca = s.struct_()?; let fields = ca .fields_as_series() @@ -205,11 +206,11 @@ pub(super) fn suffix_fields(s: &Series, suffix: Arc) -> PolarsResult>(); - let mut out = StructChunked::from_series(ca.name(), &fields)?; + let mut out = StructChunked::from_series(ca.name().clone(), &fields)?; out.zip_outer_validity(ca); Ok(out.into_series()) } @@ -224,7 +225,7 @@ pub(super) fn to_json(s: &Series) -> PolarsResult { polars_json::json::write::serialize_to_utf8(arr.as_ref()) }); - Ok(StringChunked::from_chunk_iter(ca.name(), iter).into_series()) + Ok(StringChunked::from_chunk_iter(ca.name().clone(), iter).into_series()) } pub(super) fn with_fields(args: &[Series]) -> PolarsResult { @@ -244,7 +245,7 @@ pub(super) fn with_fields(args: &[Series]) -> PolarsResult { } let new_fields = fields.into_values().cloned().collect::>(); - let mut out = StructChunked::from_series(ca.name(), &new_fields)?; + let mut out = StructChunked::from_series(ca.name().clone(), &new_fields)?; 
out.zip_outer_validity(ca); Ok(out.into_series()) } diff --git a/crates/polars-plan/src/dsl/function_expr/temporal.rs b/crates/polars-plan/src/dsl/function_expr/temporal.rs index 41fbb2d737d8..18340a00adaf 100644 --- a/crates/polars-plan/src/dsl/function_expr/temporal.rs +++ b/crates/polars-plan/src/dsl/function_expr/temporal.rs @@ -14,7 +14,7 @@ impl From for SpecialEq> { Quarter => map!(datetime::quarter), Week => map!(datetime::week), WeekDay => map!(datetime::weekday), - Duration(tu) => map_as_slice!(datetime::duration, tu), + Duration(tu) => map_as_slice!(impl_duration, tu), Day => map!(datetime::day), OrdinalDay => map!(datetime::ordinal_day), Time => map!(datetime::time), @@ -178,7 +178,7 @@ pub(super) fn datetime( }; let mut s = ca.into_series(); - s.rename("datetime"); + s.rename(PlSmallStr::from_static("datetime")); Ok(s) } diff --git a/crates/polars-plan/src/dsl/functions/concat.rs b/crates/polars-plan/src/dsl/functions/concat.rs index 6f420c72f768..d15b1769cf3a 100644 --- a/crates/polars-plan/src/dsl/functions/concat.rs +++ b/crates/polars-plan/src/dsl/functions/concat.rs @@ -4,7 +4,7 @@ use super::*; /// Horizontally concat string columns in linear time pub fn concat_str>(s: E, separator: &str, ignore_nulls: bool) -> Expr { let input = s.as_ref().to_vec(); - let separator = separator.to_string(); + let separator = separator.into(); Expr::Function { input, diff --git a/crates/polars-plan/src/dsl/functions/correlation.rs b/crates/polars-plan/src/dsl/functions/correlation.rs index bb0fc5aa3cf1..dd7521ad20a9 100644 --- a/crates/polars-plan/src/dsl/functions/correlation.rs +++ b/crates/polars-plan/src/dsl/functions/correlation.rs @@ -79,11 +79,15 @@ pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { ..Default::default() }; + let non_null_mask = when(x.clone().is_not_null().and(y.clone().is_not_null())) + .then(lit(1.0)) + .otherwise(lit(Null {})); + let mean_x_y = (x.clone() * y.clone()).rolling_mean(rolling_options.clone()); - let mean_x = x.clone().rolling_mean(rolling_options.clone()); - let mean_y = y.clone().rolling_mean(rolling_options.clone()); - let var_x = x.clone().rolling_var(rolling_options.clone()); - let var_y = y.clone().rolling_var(rolling_options); + let mean_x = (x.clone() * non_null_mask.clone()).rolling_mean(rolling_options.clone()); + let mean_y = (y.clone() * non_null_mask.clone()).rolling_mean(rolling_options.clone()); + let var_x = (x.clone() * non_null_mask.clone()).rolling_var(rolling_options.clone()); + let var_y = (y.clone() * non_null_mask.clone()).rolling_var(rolling_options); let rolling_options_count = RollingOptionsFixedWindow { window_size: options.window_size as usize, @@ -110,9 +114,13 @@ pub fn rolling_cov(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { ..Default::default() }; + let non_null_mask = when(x.clone().is_not_null().and(y.clone().is_not_null())) + .then(lit(1.0)) + .otherwise(lit(Null {})); + let mean_x_y = (x.clone() * y.clone()).rolling_mean(rolling_options.clone()); - let mean_x = x.clone().rolling_mean(rolling_options.clone()); - let mean_y = y.clone().rolling_mean(rolling_options); + let mean_x = (x.clone() * non_null_mask.clone()).rolling_mean(rolling_options.clone()); + let mean_y = (y.clone() * non_null_mask.clone()).rolling_mean(rolling_options); let rolling_options_count = RollingOptionsFixedWindow { window_size: options.window_size as usize, min_periods: 0, diff --git a/crates/polars-plan/src/dsl/functions/horizontal.rs b/crates/polars-plan/src/dsl/functions/horizontal.rs index 
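// Editorial note, not part of the patch: the `non_null_mask` added to `rolling_corr` and
// `rolling_cov` above multiplies each input by 1.0 where both inputs are non-null and by null
// otherwise, so the rolling means and variances are taken over exactly the rows that also feed
// `mean_x_y` in the estimator cov(x, y) ~ E[xy] - E[x]*E[y]. A minimal sketch of that masking
// rule over one window, modelling nulls as `None`:
fn masked_mean(x: &[Option<f64>], y: &[Option<f64>]) -> Option<f64> {
    let kept: Vec<f64> = x
        .iter()
        .zip(y)
        .filter(|(a, b)| a.is_some() && b.is_some()) // drop rows where either side is null
        .map(|(a, _)| a.unwrap())
        .collect();
    if kept.is_empty() {
        None
    } else {
        Some(kept.iter().sum::<f64>() / kept.len() as f64)
    }
}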
1b49791ebc26..eb0c79b3b0f7 100644 --- a/crates/polars-plan/src/dsl/functions/horizontal.rs +++ b/crates/polars-plan/src/dsl/functions/horizontal.rs @@ -8,11 +8,11 @@ fn cum_fold_dtype() -> GetOutput { st = get_supertype(&st, &fld.dtype).unwrap(); } Ok(Field::new( - &fields[0].name, + fields[0].name.clone(), DataType::Struct( fields .iter() - .map(|fld| Field::new(fld.name(), st.clone())) + .map(|fld| Field::new(fld.name().clone(), st.clone())) .collect(), ), )) @@ -118,15 +118,16 @@ where let mut result = vec![acc.clone()]; for s in s_iter { - let name = s.name().to_string(); + let name = s.name().clone(); if let Some(a) = f(acc.clone(), s.clone())? { acc = a; } - acc.rename(&name); + acc.rename(name); result.push(acc.clone()); } - StructChunked::from_series(acc.name(), &result).map(|ca| Some(ca.into_series())) + StructChunked::from_series(acc.name().clone(), &result) + .map(|ca| Some(ca.into_series())) }, None => Err(polars_err!(ComputeError: "`reduce` did not have any expressions to fold")), } @@ -167,15 +168,15 @@ where } for s in series { - let name = s.name().to_string(); + let name = s.name().clone(); if let Some(a) = f(acc.clone(), s)? { acc = a; - acc.rename(&name); + acc.rename(name); result.push(acc.clone()); } } - StructChunked::from_series(acc.name(), &result).map(|ca| Some(ca.into_series())) + StructChunked::from_series(acc.name().clone(), &result).map(|ca| Some(ca.into_series())) }) as Arc); Expr::AnonymousFunction { diff --git a/crates/polars-plan/src/dsl/functions/index.rs b/crates/polars-plan/src/dsl/functions/index.rs index d125ce571307..a3c840125181 100644 --- a/crates/polars-plan/src/dsl/functions/index.rs +++ b/crates/polars-plan/src/dsl/functions/index.rs @@ -1,6 +1,7 @@ use super::*; /// Find the indexes that would sort these series in order of appearance. +/// /// That means that the first `Series` will be used to determine the ordering /// until duplicates are found. Once duplicates are found, the next `Series` will /// be used and so on. @@ -10,7 +11,7 @@ pub fn arg_sort_by>(by: E, sort_options: SortMultipleOptions) - let name = expr_output_name(e).unwrap(); int_range(lit(0 as IdxSize), len().cast(IDX_DTYPE), 1, IDX_DTYPE) .sort_by(by, sort_options) - .alias(name.as_ref()) + .alias(name) } #[cfg(feature = "arg_where")] diff --git a/crates/polars-plan/src/dsl/functions/mod.rs b/crates/polars-plan/src/dsl/functions/mod.rs index 8b8fe24c7163..9219704bfc2f 100644 --- a/crates/polars-plan/src/dsl/functions/mod.rs +++ b/crates/polars-plan/src/dsl/functions/mod.rs @@ -17,6 +17,7 @@ mod range; mod repeat; mod selectors; mod syntactic_sugar; +#[cfg(feature = "temporal")] mod temporal; pub use arity::*; @@ -41,6 +42,7 @@ pub use range::*; pub use repeat::*; pub use selectors::*; pub use syntactic_sugar::*; +#[cfg(feature = "temporal")] pub use temporal::*; #[cfg(feature = "arg_where")] diff --git a/crates/polars-plan/src/dsl/functions/repeat.rs b/crates/polars-plan/src/dsl/functions/repeat.rs index 1b32abc97b5a..5c3084fb7caf 100644 --- a/crates/polars-plan/src/dsl/functions/repeat.rs +++ b/crates/polars-plan/src/dsl/functions/repeat.rs @@ -1,8 +1,9 @@ use super::*; -/// Create a column of length `n` containing `n` copies of the literal `value`. Generally you won't need this function, -/// as `lit(value)` already represents a column containing only `value` whose length is automatically set to the correct -/// number of rows. +/// Create a column of length `n` containing `n` copies of the literal `value`. 
+/// +/// Generally you won't need this function, as `lit(value)` already represents a column containing +/// only `value` whose length is automatically set to the correct number of rows. pub fn repeat>(value: E, n: Expr) -> Expr { let function = |s: Series, n: Series| { polars_ensure!( @@ -15,5 +16,6 @@ pub fn repeat>(value: E, n: Expr) -> Expr { )?; Ok(Some(s.new_from_index(0, n))) }; - apply_binary(value.into(), n, function, GetOutput::same_type()).alias("repeat") + apply_binary(value.into(), n, function, GetOutput::same_type()) + .alias(PlSmallStr::from_static("repeat")) } diff --git a/crates/polars-plan/src/dsl/functions/selectors.rs b/crates/polars-plan/src/dsl/functions/selectors.rs index 11c92a40b1ba..28d52c10f835 100644 --- a/crates/polars-plan/src/dsl/functions/selectors.rs +++ b/crates/polars-plan/src/dsl/functions/selectors.rs @@ -24,10 +24,14 @@ use super::*; /// // only if regex features is activated /// col("^foo.*$") /// ``` -pub fn col(name: &str) -> Expr { - match name { +pub fn col(name: S) -> Expr +where + S: Into, +{ + let name = name.into(); + match name.as_str() { "*" => Expr::Wildcard, - _ => Expr::Column(ColumnName::from(name)), + _ => Expr::Column(name), } } @@ -37,12 +41,12 @@ pub fn all() -> Expr { } /// Select multiple columns by name. -pub fn cols>(names: I) -> Expr { - let names = names.into_vec(); - let names = names - .into_iter() - .map(|v| ColumnName::from(v.as_str())) - .collect(); +pub fn cols(names: I) -> Expr +where + I: IntoIterator, + S: Into, +{ + let names = names.into_iter().map(|x| x.into()).collect(); Expr::Columns(names) } diff --git a/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs b/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs index 5e1e45a0124c..e1ef64ee02ec 100644 --- a/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs +++ b/crates/polars-plan/src/dsl/functions/syntactic_sugar.rs @@ -57,10 +57,10 @@ pub fn is_not_null(expr: Expr) -> Expr { /// Follows the rules of Rust casting, with the exception that integers and floats can be cast to `DataType::Date` and /// `DataType::DateTime(_, _)`. A column consisting entirely of of `Null` can be cast to any type, regardless of the /// nominal type of the column. -pub fn cast(expr: Expr, data_type: DataType) -> Expr { +pub fn cast(expr: Expr, dtype: DataType) -> Expr { Expr::Cast { expr: Arc::new(expr), - data_type, + dtype, options: CastOptions::NonStrict, } } diff --git a/crates/polars-plan/src/dsl/functions/temporal.rs b/crates/polars-plan/src/dsl/functions/temporal.rs index 0289e69c6514..145b521092d3 100644 --- a/crates/polars-plan/src/dsl/functions/temporal.rs +++ b/crates/polars-plan/src/dsl/functions/temporal.rs @@ -1,3 +1,5 @@ +use chrono::{Datelike, Timelike}; + use super::*; macro_rules! 
impl_unit_setter { @@ -107,11 +109,83 @@ impl DatetimeArgs { pub fn with_ambiguous(self, ambiguous: Expr) -> Self { Self { ambiguous, ..self } } + + fn all_literal(&self) -> bool { + use Expr::*; + [ + &self.year, + &self.month, + &self.day, + &self.hour, + &self.minute, + &self.second, + &self.microsecond, + ] + .iter() + .all(|e| matches!(e, Literal(_))) + } + + fn as_literal(&self) -> Option { + if self.time_zone.is_some() || !self.all_literal() { + return None; + }; + let Expr::Literal(lv) = &self.year else { + unreachable!() + }; + let year = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.month else { + unreachable!() + }; + let month = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.day else { + unreachable!() + }; + let day = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.hour else { + unreachable!() + }; + let hour = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.minute else { + unreachable!() + }; + let minute = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.second else { + unreachable!() + }; + let second = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.microsecond else { + unreachable!() + }; + let ms: u32 = lv.to_any_value()?.extract()?; + + let dt = chrono::NaiveDateTime::default() + .with_year(year)? + .with_month(month)? + .with_day(day)? + .with_hour(hour)? + .with_minute(minute)? + .with_second(second)? + .with_nanosecond(ms * 1000)?; + + let ts = match self.time_unit { + TimeUnit::Milliseconds => dt.and_utc().timestamp_millis(), + TimeUnit::Microseconds => dt.and_utc().timestamp_micros(), + TimeUnit::Nanoseconds => dt.and_utc().timestamp_nanos_opt()?, + }; + + Some( + Expr::Literal(LiteralValue::DateTime(ts, self.time_unit, None)) + .alias(PlSmallStr::from_static("datetime")), + ) + } } /// Construct a column of `Datetime` from the provided [`DatetimeArgs`]. 
-#[cfg(feature = "temporal")] pub fn datetime(args: DatetimeArgs) -> Expr { + if let Some(e) = args.as_literal() { + return e; + } + let year = args.year; let month = args.month; let day = args.day; @@ -253,11 +327,88 @@ impl DurationArgs { impl_unit_setter!(with_milliseconds(milliseconds)); impl_unit_setter!(with_microseconds(microseconds)); impl_unit_setter!(with_nanoseconds(nanoseconds)); + + fn all_literal(&self) -> bool { + use Expr::*; + [ + &self.weeks, + &self.days, + &self.hours, + &self.seconds, + &self.minutes, + &self.milliseconds, + &self.microseconds, + &self.nanoseconds, + ] + .iter() + .all(|e| matches!(e, Literal(_))) + } + + fn as_literal(&self) -> Option { + if !self.all_literal() { + return None; + }; + let Expr::Literal(lv) = &self.weeks else { + unreachable!() + }; + let weeks = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.days else { + unreachable!() + }; + let days = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.hours else { + unreachable!() + }; + let hours = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.seconds else { + unreachable!() + }; + let seconds = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.minutes else { + unreachable!() + }; + let minutes = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.milliseconds else { + unreachable!() + }; + let milliseconds = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.microseconds else { + unreachable!() + }; + let microseconds = lv.to_any_value()?.extract()?; + let Expr::Literal(lv) = &self.nanoseconds else { + unreachable!() + }; + let nanoseconds = lv.to_any_value()?.extract()?; + + type D = chrono::Duration; + let delta = D::weeks(weeks) + + D::days(days) + + D::hours(hours) + + D::seconds(seconds) + + D::minutes(minutes) + + D::milliseconds(milliseconds) + + D::microseconds(microseconds) + + D::nanoseconds(nanoseconds); + + let d = match self.time_unit { + TimeUnit::Milliseconds => delta.num_milliseconds(), + TimeUnit::Microseconds => delta.num_microseconds()?, + TimeUnit::Nanoseconds => delta.num_nanoseconds()?, + }; + + Some( + Expr::Literal(LiteralValue::Duration(d, self.time_unit)) + .alias(PlSmallStr::from_static("duration")), + ) + } } /// Construct a column of [`Duration`] from the provided [`DurationArgs`] -#[cfg(feature = "temporal")] pub fn duration(args: DurationArgs) -> Expr { + if let Some(e) = args.as_literal() { + return e; + } Expr::Function { input: vec![ args.weeks, diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index 2dc0c72c1872..11e825a7ec1f 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -5,6 +5,8 @@ use polars_core::prelude::*; #[cfg(feature = "diff")] use polars_core::series::ops::NullBehavior; #[cfg(feature = "list_sets")] +use polars_core::utils::SuperTypeFlags; +#[cfg(feature = "list_sets")] use polars_core::utils::SuperTypeOptions; use crate::prelude::function_expr::ListFunction; @@ -51,7 +53,7 @@ impl ListNameSpace { }), &[n], false, - false, + None, ) } @@ -72,7 +74,7 @@ impl ListNameSpace { }), &[fraction], false, - false, + None, ) } @@ -158,7 +160,7 @@ impl ListNameSpace { FunctionExpr::ListExpr(ListFunction::Get(null_on_oob)), &[index], false, - false, + None, ) } @@ -173,7 +175,7 @@ impl ListNameSpace { FunctionExpr::ListExpr(ListFunction::Gather(null_on_oob)), &[index], false, - false, + None, ) } @@ -183,7 +185,7 @@ impl ListNameSpace { FunctionExpr::ListExpr(ListFunction::GatherEvery), &[n, offset], 
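// Editorial sketch, not part of the patch: the `as_literal` helpers above let an all-literal
// DatetimeArgs / DurationArgs collapse into a single Datetime / Duration literal when the DSL
// expression is built, instead of going through the runtime kernel. A minimal illustration of
// the datetime case using chrono (already imported by this hunk), assuming microsecond
// precision and no time zone:
fn fold_datetime_us(y: i32, mo: u32, d: u32, h: u32, mi: u32, s: u32, us: u32) -> Option<i64> {
    use chrono::NaiveDate;
    let dt = NaiveDate::from_ymd_opt(y, mo, d)?.and_hms_micro_opt(h, mi, s, us)?;
    Some(dt.and_utc().timestamp_micros()) // the value stored in the folded literal
}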
false, - false, + None, ) } @@ -205,7 +207,7 @@ impl ListNameSpace { FunctionExpr::ListExpr(ListFunction::Join(ignore_nulls)), &[separator], false, - false, + None, ) } @@ -237,7 +239,7 @@ impl ListNameSpace { FunctionExpr::ListExpr(ListFunction::Shift), &[periods], false, - false, + None, ) } @@ -247,7 +249,7 @@ impl ListNameSpace { FunctionExpr::ListExpr(ListFunction::Slice), &[offset, length], false, - false, + None, ) } @@ -311,7 +313,7 @@ impl ListNameSpace { let fields = (0..upper_bound) .map(|i| { let name = _default_struct_name_gen(i); - Field::from_owned(name, inner.clone()) + Field::new(name, inner.clone()) }) .collect(); let dt = DataType::Struct(fields); @@ -335,7 +337,7 @@ impl ListNameSpace { FunctionExpr::ListExpr(ListFunction::Contains), &[other], false, - false, + None, ) .with_function_options(|mut options| { options.flags |= FunctionFlags::INPUT_WILDCARD_EXPANSION; @@ -352,7 +354,7 @@ impl ListNameSpace { FunctionExpr::ListExpr(ListFunction::CountMatches), &[other], false, - false, + None, ) .with_function_options(|mut options| { options.flags |= FunctionFlags::INPUT_WILDCARD_EXPANSION; @@ -367,7 +369,9 @@ impl ListNameSpace { function: FunctionExpr::ListExpr(ListFunction::SetOperation(set_operation)), options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, - cast_to_supertypes: Some(SuperTypeOptions { implode_list: true }), + cast_to_supertypes: Some(SuperTypeOptions { + flags: SuperTypeFlags::default() | SuperTypeFlags::ALLOW_IMPLODE_LIST, + }), flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION & !FunctionFlags::RETURNS_SCALAR, ..Default::default() diff --git a/crates/polars-plan/src/dsl/meta.rs b/crates/polars-plan/src/dsl/meta.rs index cfff5da07076..0e7a30fa024b 100644 --- a/crates/polars-plan/src/dsl/meta.rs +++ b/crates/polars-plan/src/dsl/meta.rs @@ -11,32 +11,33 @@ pub struct MetaNameSpace(pub(crate) Expr); impl MetaNameSpace { /// Pop latest expression and return the input(s) of the popped expression. - pub fn pop(self) -> Vec { + pub fn pop(self) -> PolarsResult> { let mut arena = Arena::with_capacity(8); - let node = to_aexpr(self.0, &mut arena); + let node = to_aexpr(self.0, &mut arena)?; let ae = arena.get(node); let mut inputs = Vec::with_capacity(2); ae.nodes(&mut inputs); - inputs + Ok(inputs .iter() .map(|node| node_to_expr(*node, &arena)) - .collect() + .collect()) } /// Get the root column names. - pub fn root_names(&self) -> Vec> { + pub fn root_names(&self) -> Vec { expr_to_leaf_column_names(&self.0) } /// A projection that only takes a column or a column + alias. pub fn is_simple_projection(&self) -> bool { let mut arena = Arena::with_capacity(8); - let node = to_aexpr(self.0.clone(), &mut arena); - aexpr_is_simple_projection(node, &arena) + to_aexpr(self.0.clone(), &mut arena) + .map(|node| aexpr_is_simple_projection(node, &arena)) + .unwrap_or(false) } /// Get the output name of this expression. 
- pub fn output_name(&self) -> PolarsResult> { + pub fn output_name(&self) -> PolarsResult { expr_output_name(&self.0) } @@ -160,7 +161,7 @@ impl MetaNameSpace { /// the expression as a tree pub fn into_tree_formatter(self) -> PolarsResult { let mut arena = Default::default(); - let node = to_aexpr(self.0, &mut arena); + let node = to_aexpr(self.0, &mut arena)?; let mut visitor = TreeFmtVisitor::default(); AexprNode::new(node).visit(&mut visitor, &arena)?; diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index c5cf50e1483c..895020ce43f5 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -62,7 +62,9 @@ use polars_core::prelude::*; #[cfg(feature = "diff")] use polars_core::series::ops::NullBehavior; use polars_core::series::IsSorted; -use polars_core::utils::try_get_supertype; +#[cfg(any(feature = "search_sorted", feature = "is_between"))] +use polars_core::utils::SuperTypeFlags; +use polars_core::utils::{try_get_supertype, SuperTypeOptions}; pub use selector::Selector; #[cfg(feature = "dtype-struct")] pub use struct_::*; @@ -168,8 +170,11 @@ impl Expr { } /// Rename Column. - pub fn alias(self, name: &str) -> Expr { - Expr::Alias(Arc::new(self), ColumnName::from(name)) + pub fn alias(self, name: S) -> Expr + where + S: Into, + { + Expr::Alias(Arc::new(self), name.into()) } /// Run is_null operation on `Expr`. @@ -320,7 +325,7 @@ impl Expr { self.function_with_options( move |s: Series| { Ok(Some(Series::new( - s.name(), + s.name().clone(), &[s.arg_min().map(|idx| idx as u32)], ))) }, @@ -341,7 +346,7 @@ impl Expr { self.function_with_options( move |s: Series| { Ok(Some(Series::new( - s.name(), + s.name().clone(), &[s.arg_max().map(|idx| idx as IdxSize)], ))) }, @@ -376,7 +381,9 @@ impl Expr { collect_groups: ApplyOptions::GroupWise, flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR, fmt_str: "search_sorted", - cast_to_supertypes: Some(Default::default()), + cast_to_supertypes: Some( + (SuperTypeFlags::default() & !SuperTypeFlags::ALLOW_PRIMITIVE_TO_STRING).into(), + ), ..Default::default() }, } @@ -384,28 +391,28 @@ impl Expr { /// Cast expression to another data type. /// Throws an error if conversion had overflows. - pub fn strict_cast(self, data_type: DataType) -> Self { + pub fn strict_cast(self, dtype: DataType) -> Self { Expr::Cast { expr: Arc::new(self), - data_type, + dtype, options: CastOptions::Strict, } } /// Cast expression to another data type. - pub fn cast(self, data_type: DataType) -> Self { + pub fn cast(self, dtype: DataType) -> Self { Expr::Cast { expr: Arc::new(self), - data_type, + dtype, options: CastOptions::NonStrict, } } /// Cast expression to another data type. 
- pub fn cast_with_options(self, data_type: DataType, cast_options: CastOptions) -> Self { + pub fn cast_with_options(self, dtype: DataType, cast_options: CastOptions) -> Self { Expr::Cast { expr: Arc::new(self), - data_type, + dtype, options: cast_options, } } @@ -722,17 +729,12 @@ impl Expr { function_expr: FunctionExpr, arguments: &[Expr], returns_scalar: bool, - cast_to_supertypes: bool, + cast_to_supertypes: Option, ) -> Self { let mut input = Vec::with_capacity(arguments.len() + 1); input.push(self); input.extend_from_slice(arguments); - let cast_to_supertypes = if cast_to_supertypes { - Some(Default::default()) - } else { - None - }; let mut flags = FunctionFlags::default(); if returns_scalar { flags |= FunctionFlags::RETURNS_SCALAR; @@ -827,7 +829,9 @@ impl Expr { }; self.function_with_options( - move |s: Series| Some(s.product().map(|sc| sc.into_series(s.name()))).transpose(), + move |s: Series| { + Some(s.product().map(|sc| sc.into_series(s.name().clone()))).transpose() + }, GetOutput::map_dtype(|dt| { use DataType as T; Ok(match dt { @@ -891,7 +895,7 @@ impl Expr { }, &[min, max], false, - false, + None, ) } @@ -905,7 +909,7 @@ impl Expr { }, &[max], false, - false, + None, ) } @@ -919,7 +923,7 @@ impl Expr { }, &[min], false, - false, + None, ) } @@ -1020,7 +1024,7 @@ impl Expr { pub fn rolling(self, options: RollingGroupOptions) -> Self { // We add the index column as `partition expr` so that the optimizer will // not ignore it. - let index_col = col(options.index_column.as_str()); + let index_col = col(options.index_column.clone()); Expr::Window { function: Arc::new(self), partition_by: vec![index_col], @@ -1086,7 +1090,7 @@ impl Expr { BooleanFunction::IsBetween { closed }.into(), &[lower.into(), upper.into()], false, - true, + Some((SuperTypeFlags::default() & !SuperTypeFlags::ALLOW_PRIMITIVE_TO_STRING).into()), ) } @@ -1163,7 +1167,7 @@ impl Expr { BooleanFunction::IsIn.into(), arguments, returns_scalar, - true, + Some(Default::default()), ) } else { self.apply_many_private( @@ -1258,12 +1262,8 @@ impl Expr { /// Exclude a column from a wildcard/regex selection. 
/// /// You may also use regexes in the exclude as long as they start with `^` and end with `$`/ - pub fn exclude(self, columns: impl IntoVec) -> Expr { - let v = columns - .into_vec() - .into_iter() - .map(|s| Excluded::Name(ColumnName::from(s))) - .collect(); + pub fn exclude(self, columns: impl IntoVec) -> Expr { + let v = columns.into_vec().into_iter().map(Excluded::Name).collect(); Expr::Exclude(Arc::new(self), v) } @@ -1499,10 +1499,10 @@ impl Expr { } }, GetOutput::map_field(|field| { - Ok(match field.data_type() { + Ok(match field.dtype() { DataType::Float64 => field.clone(), - DataType::Float32 => Field::new(field.name(), DataType::Float32), - _ => Field::new(field.name(), DataType::Float64), + DataType::Float32 => Field::new(field.name().clone(), DataType::Float32), + _ => Field::new(field.name().clone(), DataType::Float64), }) }), ) @@ -1537,7 +1537,7 @@ impl Expr { let args = [old, new]; if literal_searchers { - self.map_many_private(FunctionExpr::Replace, &args, false, false) + self.map_many_private(FunctionExpr::Replace, &args, false, None) } else { self.apply_many_private(FunctionExpr::Replace, &args, false, false) } @@ -1568,7 +1568,7 @@ impl Expr { FunctionExpr::ReplaceStrict { return_dtype }, &args, false, - false, + None, ) } else { self.apply_many_private( @@ -1585,13 +1585,13 @@ impl Expr { pub fn cut( self, breaks: Vec, - labels: Option>, + labels: Option>, left_closed: bool, include_breaks: bool, ) -> Expr { self.apply_private(FunctionExpr::Cut { breaks, - labels, + labels: labels.map(|x| x.into_vec()), left_closed, include_breaks, }) @@ -1606,14 +1606,14 @@ impl Expr { pub fn qcut( self, probs: Vec, - labels: Option>, + labels: Option>, left_closed: bool, allow_duplicates: bool, include_breaks: bool, ) -> Expr { self.apply_private(FunctionExpr::QCut { probs, - labels, + labels: labels.map(|x| x.into_vec()), left_closed, allow_duplicates, include_breaks, @@ -1629,7 +1629,7 @@ impl Expr { pub fn qcut_uniform( self, n_bins: usize, - labels: Option>, + labels: Option>, left_closed: bool, allow_duplicates: bool, include_breaks: bool, @@ -1637,7 +1637,7 @@ impl Expr { let probs = (1..n_bins).map(|b| b as f64 / n_bins as f64).collect(); self.apply_private(FunctionExpr::QCut { probs, - labels, + labels: labels.map(|x| x.into_vec()), left_closed, allow_duplicates, include_breaks, @@ -1708,12 +1708,20 @@ impl Expr { /// Get maximal value that could be hold by this dtype. pub fn upper_bound(self) -> Expr { - self.map_private(FunctionExpr::UpperBound) + self.apply_private(FunctionExpr::UpperBound) + .with_function_options(|mut options| { + options.flags |= FunctionFlags::RETURNS_SCALAR; + options + }) } /// Get minimal value that could be hold by this dtype. pub fn lower_bound(self) -> Expr { - self.map_private(FunctionExpr::LowerBound) + self.apply_private(FunctionExpr::LowerBound) + .with_function_options(|mut options| { + options.flags |= FunctionFlags::RETURNS_SCALAR; + options + }) } pub fn reshape(self, dimensions: &[i64], nested_type: NestedType) -> Self { @@ -1790,11 +1798,11 @@ impl Expr { #[cfg(feature = "dtype-struct")] /// Count all unique values and create a struct mapping value to count. /// (Note that it is better to turn parallel off in the aggregation context). 
- pub fn value_counts(self, sort: bool, parallel: bool, name: String, normalize: bool) -> Self { + pub fn value_counts(self, sort: bool, parallel: bool, name: &str, normalize: bool) -> Self { self.apply_private(FunctionExpr::ValueCounts { sort, parallel, - name, + name: name.into(), normalize, }) .with_function_options(|mut opts| { diff --git a/crates/polars-plan/src/dsl/name.rs b/crates/polars-plan/src/dsl/name.rs index ab7231b2e151..70bbc830b3c0 100644 --- a/crates/polars-plan/src/dsl/name.rs +++ b/crates/polars-plan/src/dsl/name.rs @@ -1,5 +1,6 @@ +use polars_utils::format_pl_smallstr; #[cfg(feature = "dtype-struct")] -use smartstring::alias::String as SmartString; +use polars_utils::pl_str::PlSmallStr; use super::*; @@ -27,7 +28,7 @@ impl ExprNameNameSpace { /// Define an alias by mapping a function over the original root column name. pub fn map(self, function: F) -> Expr where - F: Fn(&str) -> PolarsResult + 'static + Send + Sync, + F: Fn(&PlSmallStr) -> PolarsResult + 'static + Send + Sync, { let function = SpecialEq::new(Arc::new(function) as Arc); Expr::RenameAlias { @@ -39,25 +40,25 @@ impl ExprNameNameSpace { /// Add a prefix to the root column name. pub fn prefix(self, prefix: &str) -> Expr { let prefix = prefix.to_string(); - self.map(move |name| Ok(format!("{prefix}{name}"))) + self.map(move |name| Ok(format_pl_smallstr!("{prefix}{name}"))) } /// Add a suffix to the root column name. pub fn suffix(self, suffix: &str) -> Expr { let suffix = suffix.to_string(); - self.map(move |name| Ok(format!("{name}{suffix}"))) + self.map(move |name| Ok(format_pl_smallstr!("{name}{suffix}"))) } /// Update the root column name to use lowercase characters. #[allow(clippy::wrong_self_convention)] pub fn to_lowercase(self) -> Expr { - self.map(move |name| Ok(name.to_lowercase())) + self.map(move |name| Ok(PlSmallStr::from_string(name.to_lowercase()))) } /// Update the root column name to use uppercase characters. 
#[allow(clippy::wrong_self_convention)] pub fn to_uppercase(self) -> Expr { - self.map(move |name| Ok(name.to_uppercase())) + self.map(move |name| Ok(PlSmallStr::from_string(name.to_uppercase()))) } #[cfg(feature = "dtype-struct")] @@ -71,11 +72,11 @@ impl ExprNameNameSpace { .iter() .map(|fd| { let mut fd = fd.clone(); - fd.rename(&function(fd.name())); + fd.rename(function(fd.name())); fd }) .collect::>(); - let mut out = StructChunked::from_series(s.name(), &fields)?; + let mut out = StructChunked::from_series(s.name().clone(), &fields)?; out.zip_outer_validity(s); Ok(Some(out.into_series())) }, @@ -83,7 +84,7 @@ impl ExprNameNameSpace { DataType::Struct(fds) => { let fields = fds .iter() - .map(|fd| Field::new(&f(fd.name()), fd.data_type().clone())) + .map(|fd| Field::new(f(fd.name()), fd.dtype().clone())) .collect(); Ok(DataType::Struct(fields)) }, @@ -96,7 +97,7 @@ impl ExprNameNameSpace { pub fn prefix_fields(self, prefix: &str) -> Expr { self.0 .map_private(FunctionExpr::StructExpr(StructFunction::PrefixFields( - ColumnName::from(prefix), + PlSmallStr::from_str(prefix), ))) } @@ -104,10 +105,10 @@ impl ExprNameNameSpace { pub fn suffix_fields(self, suffix: &str) -> Expr { self.0 .map_private(FunctionExpr::StructExpr(StructFunction::SuffixFields( - ColumnName::from(suffix), + PlSmallStr::from_str(suffix), ))) } } #[cfg(feature = "dtype-struct")] -pub type FieldsNameMapper = Arc SmartString + Send + Sync>; +pub type FieldsNameMapper = Arc PlSmallStr + Send + Sync>; diff --git a/crates/polars-plan/src/dsl/options.rs b/crates/polars-plan/src/dsl/options.rs index 4fb128783cac..a4d9ae84cd73 100644 --- a/crates/polars-plan/src/dsl/options.rs +++ b/crates/polars-plan/src/dsl/options.rs @@ -1,10 +1,13 @@ use polars_ops::prelude::{JoinArgs, JoinType}; #[cfg(feature = "dynamic_group_by")] use polars_time::RollingGroupOptions; +use polars_utils::pl_str::PlSmallStr; use polars_utils::IdxSize; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use crate::dsl::Selector; + #[derive(Copy, Clone, PartialEq, Debug, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct RollingCovOptions { @@ -17,7 +20,7 @@ pub struct RollingCovOptions { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct StrptimeOptions { /// Formatting string - pub format: Option, + pub format: Option, /// If set then polars will return an error if any date parsing fails pub strict: bool, /// If polars may parse matches that not contain the whole string @@ -105,3 +108,12 @@ pub enum NestedType { Array, List, } + +#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct UnpivotArgsDSL { + pub on: Vec, + pub index: Vec, + pub variable_name: Option, + pub value_name: Option, +} diff --git a/crates/polars-plan/src/dsl/python_udf.rs b/crates/polars-plan/src/dsl/python_udf.rs index 5fcd5e9b797a..b105f62df482 100644 --- a/crates/polars-plan/src/dsl/python_udf.rs +++ b/crates/polars-plan/src/dsl/python_udf.rs @@ -5,6 +5,7 @@ use polars_core::datatypes::{DataType, Field}; use polars_core::error::*; use polars_core::frame::DataFrame; use polars_core::prelude::Series; +use polars_core::schema::Schema; use pyo3::prelude::*; use pyo3::pybacked::PyBackedBytes; use pyo3::types::PyBytes; @@ -17,14 +18,14 @@ use super::expr_dyn_fn::*; use crate::constants::MAP_LIST_NAME; use crate::prelude::*; -// Will be overwritten on python polar start up. +// Will be overwritten on Python Polars start up. 
pub static mut CALL_SERIES_UDF_PYTHON: Option< fn(s: Series, lambda: &PyObject) -> PolarsResult, > = None; pub static mut CALL_DF_UDF_PYTHON: Option< fn(s: DataFrame, lambda: &PyObject) -> PolarsResult, > = None; -pub(super) const MAGIC_BYTE_MARK: &[u8] = "POLARS_PYTHON_UDF".as_bytes(); +pub(super) const MAGIC_BYTE_MARK: &[u8] = "PLPYUDF".as_bytes(); #[derive(Clone, Debug)] pub struct PythonFunction(pub PyObject); @@ -141,7 +142,7 @@ impl PythonUdfExpression { .unwrap(); let arg = (PyBytes::new_bound(py, remainder),); let python_function = pickle.call1(arg).map_err(from_pyerr)?; - Ok(Arc::new(PythonUdfExpression::new( + Ok(Arc::new(Self::new( python_function.into(), output_type, is_elementwise, @@ -218,7 +219,7 @@ impl SeriesUdf for PythonUdfExpression { let output_type = self.output_type.clone(); Some(GetOutput::map_field(move |fld| { Ok(match output_type { - Some(ref dt) => Field::new(fld.name(), dt.clone()), + Some(ref dt) => Field::new(fld.name().clone(), dt.clone()), None => { let mut fld = fld.clone(); fld.coerce(DataType::Unknown(Default::default())); @@ -229,6 +230,54 @@ impl SeriesUdf for PythonUdfExpression { } } +/// Serializable version of [`GetOutput`] for Python UDFs. +pub struct PythonGetOutput { + return_dtype: Option, +} + +impl PythonGetOutput { + pub fn new(return_dtype: Option) -> Self { + Self { return_dtype } + } + + #[cfg(feature = "serde")] + pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult> { + // Skip header. + debug_assert!(buf.starts_with(MAGIC_BYTE_MARK)); + let buf = &buf[MAGIC_BYTE_MARK.len()..]; + + let mut reader = Cursor::new(buf); + let return_dtype: Option = + ciborium::de::from_reader(&mut reader).map_err(map_err)?; + + Ok(Arc::new(Self::new(return_dtype)) as Arc) + } +} + +impl FunctionOutputField for PythonGetOutput { + fn get_field( + &self, + _input_schema: &Schema, + _cntxt: Context, + fields: &[Field], + ) -> PolarsResult { + // Take the name of first field, just like [`GetOutput::map_field`]. 
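Not part of the patch: `PythonGetOutput` frames its payload the same way the UDF bytes are framed, a `MAGIC_BYTE_MARK` header followed by a CBOR body. A standalone round-trip sketch of that framing, with a plain `String` standing in for the serialized `Option<DataType>`:

use std::io::Cursor;

const MAGIC_BYTE_MARK: &[u8] = b"PLPYUDF";

// Header + CBOR body, as in `try_serialize`.
fn serialize(payload: &Option<String>, buf: &mut Vec<u8>) {
    buf.extend_from_slice(MAGIC_BYTE_MARK);
    ciborium::ser::into_writer(payload, &mut *buf).unwrap();
}

// Skip the header, then decode the CBOR body, as in `try_deserialize`.
fn deserialize(buf: &[u8]) -> Option<String> {
    debug_assert!(buf.starts_with(MAGIC_BYTE_MARK));
    let mut reader = Cursor::new(&buf[MAGIC_BYTE_MARK.len()..]);
    ciborium::de::from_reader(&mut reader).unwrap()
}

fn main() {
    let mut buf = Vec::new();
    serialize(&Some("Int64".to_string()), &mut buf);
    assert_eq!(deserialize(&buf), Some("Int64".to_string()));
}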
+ let name = fields[0].name(); + let return_dtype = match self.return_dtype { + Some(ref dtype) => dtype.clone(), + None => DataType::Unknown(Default::default()), + }; + Ok(Field::new(name.clone(), return_dtype)) + } + + #[cfg(feature = "serde")] + fn try_serialize(&self, buf: &mut Vec) -> PolarsResult<()> { + buf.extend_from_slice(MAGIC_BYTE_MARK); + ciborium::ser::into_writer(&self.return_dtype, &mut *buf).unwrap(); + Ok(()) + } +} + impl Expr { pub fn map_python(self, func: PythonUdfExpression, agg_list: bool) -> Expr { let (collect_groups, name) = if agg_list { @@ -241,16 +290,10 @@ impl Expr { let returns_scalar = func.returns_scalar; let return_dtype = func.output_type.clone(); - let output_type = GetOutput::map_field(move |fld| { - Ok(match return_dtype { - Some(ref dt) => Field::new(fld.name(), dt.clone()), - None => { - let mut fld = fld.clone(); - fld.coerce(DataType::Unknown(Default::default())); - fld - }, - }) - }); + + let output_field = PythonGetOutput::new(return_dtype); + let output_type = SpecialEq::new(Arc::new(output_field) as Arc); + let mut flags = FunctionFlags::default() | FunctionFlags::OPTIONAL_RE_ENTRANT; if returns_scalar { flags |= FunctionFlags::RETURNS_SCALAR; diff --git a/crates/polars-plan/src/dsl/selector.rs b/crates/polars-plan/src/dsl/selector.rs index b19781de1024..16e7d7b374e0 100644 --- a/crates/polars-plan/src/dsl/selector.rs +++ b/crates/polars-plan/src/dsl/selector.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use super::*; -#[derive(Clone, PartialEq, Hash)] +#[derive(Clone, PartialEq, Hash, Debug, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum Selector { Add(Box, Box), @@ -58,18 +58,18 @@ impl Sub for Selector { impl From<&str> for Selector { fn from(value: &str) -> Self { - Selector::new(col(value)) + Selector::new(col(PlSmallStr::from_str(value))) } } impl From for Selector { fn from(value: String) -> Self { - Selector::new(col(value.as_ref())) + Selector::new(col(PlSmallStr::from_string(value))) } } -impl From for Selector { - fn from(value: ColumnName) -> Self { +impl From for Selector { + fn from(value: PlSmallStr) -> Self { Selector::new(Expr::Column(value)) } } diff --git a/crates/polars-plan/src/dsl/string.rs b/crates/polars-plan/src/dsl/string.rs index 3346c4413d68..d392d403d1b6 100644 --- a/crates/polars-plan/src/dsl/string.rs +++ b/crates/polars-plan/src/dsl/string.rs @@ -13,7 +13,7 @@ impl StringNameSpace { }), &[pat], false, - true, + Some(Default::default()), ) } @@ -28,7 +28,7 @@ impl StringNameSpace { }), &[pat], false, - true, + Some(Default::default()), ) } @@ -47,7 +47,7 @@ impl StringNameSpace { }), &[patterns], false, - false, + None, ) } @@ -71,7 +71,7 @@ impl StringNameSpace { }), &[patterns, replace_with], false, - false, + None, ) } @@ -96,7 +96,7 @@ impl StringNameSpace { }), &[patterns], false, - false, + None, ) } @@ -106,7 +106,7 @@ impl StringNameSpace { FunctionExpr::StringExpr(StringFunction::EndsWith), &[sub], false, - true, + Some(Default::default()), ) } @@ -116,7 +116,7 @@ impl StringNameSpace { FunctionExpr::StringExpr(StringFunction::StartsWith), &[sub], false, - true, + Some(Default::default()), ) } @@ -152,7 +152,7 @@ impl StringNameSpace { StringFunction::Extract(group_index).into(), &[pat], false, - true, + Some(Default::default()), ) } @@ -161,6 +161,8 @@ impl StringNameSpace { pub fn extract_groups(self, pat: &str) -> PolarsResult { // regex will be compiled twice, because it doesn't support serde // and we need to compile it here to determine the output datatype + 
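Not part of the patch: the new `From` impls let plain strings become `Selector`s directly, which is what the selector-based fields of `UnpivotArgsDSL` (added in options.rs above) consume. The element types of `on`/`index` are assumed to be `Selector` based on the import added there; names are illustrative.

use polars_plan::prelude::*;
use polars_utils::pl_str::PlSmallStr;

fn unpivot_args() -> UnpivotArgsDSL {
    // `&str`, `String` and `PlSmallStr` all convert into a column selector now.
    let a: Selector = "a".into();
    let b: Selector = String::from("b").into();
    let id: Selector = PlSmallStr::from_static("id").into();

    UnpivotArgsDSL {
        on: vec![a, b],
        index: vec![id],
        variable_name: Some(PlSmallStr::from_static("variable")),
        value_name: None,
    }
}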
+ use polars_utils::format_pl_smallstr; let reg = regex::Regex::new(pat)?; let names = reg .capture_names() @@ -168,22 +170,22 @@ impl StringNameSpace { .skip(1) .map(|(idx, opt_name)| { opt_name - .map(|name| name.to_string()) - .unwrap_or_else(|| format!("{idx}")) + .map(PlSmallStr::from_str) + .unwrap_or_else(|| format_pl_smallstr!("{idx}")) }) .collect::>(); let dtype = DataType::Struct( names .iter() - .map(|name| Field::new(name.as_str(), DataType::String)) + .map(|name| Field::new(name.clone(), DataType::String)) .collect(), ); Ok(self.0.map_private( StringFunction::ExtractGroups { dtype, - pat: pat.to_string(), + pat: pat.into(), } .into(), )) @@ -220,7 +222,7 @@ impl StringNameSpace { #[cfg(feature = "string_pad")] pub fn zfill(self, length: Expr) -> Expr { self.0 - .map_many_private(StringFunction::ZFill.into(), &[length], false, false) + .map_many_private(StringFunction::ZFill.into(), &[length], false, None) } /// Find the index of a literal substring within another string value. @@ -233,7 +235,7 @@ impl StringNameSpace { }), &[pat], false, - true, + Some(Default::default()), ) } @@ -247,14 +249,14 @@ impl StringNameSpace { }), &[pat], false, - true, + Some(Default::default()), ) } /// Extract each successive non-overlapping match in an individual string as an array pub fn extract_all(self, pat: Expr) -> Expr { self.0 - .map_many_private(StringFunction::ExtractAll.into(), &[pat], false, false) + .map_many_private(StringFunction::ExtractAll.into(), &[pat], false, None) } /// Count all successive non-overlapping regex matches. @@ -263,7 +265,7 @@ impl StringNameSpace { StringFunction::CountMatches(literal).into(), &[pat], false, - false, + None, ) } @@ -274,7 +276,7 @@ impl StringNameSpace { StringFunction::Strptime(dtype, options).into(), &[ambiguous], false, - false, + None, ) } @@ -333,7 +335,7 @@ impl StringNameSpace { self.0 .apply_private( StringFunction::ConcatVertical { - delimiter: delimiter.to_owned(), + delimiter: delimiter.into(), ignore_nulls, } .into(), @@ -348,13 +350,13 @@ impl StringNameSpace { /// Split the string by a substring. The resulting dtype is `List`. pub fn split(self, by: Expr) -> Expr { self.0 - .map_many_private(StringFunction::Split(false).into(), &[by], false, false) + .map_many_private(StringFunction::Split(false).into(), &[by], false, None) } /// Split the string by a substring and keep the substring. The resulting dtype is `List`. pub fn split_inclusive(self, by: Expr) -> Expr { self.0 - .map_many_private(StringFunction::Split(true).into(), &[by], false, false) + .map_many_private(StringFunction::Split(true).into(), &[by], false, None) } #[cfg(feature = "dtype-struct")] @@ -368,7 +370,7 @@ impl StringNameSpace { .into(), &[by], false, - false, + None, ) } @@ -380,7 +382,7 @@ impl StringNameSpace { StringFunction::SplitExact { n, inclusive: true }.into(), &[by], false, - false, + None, ) } @@ -389,7 +391,7 @@ impl StringNameSpace { /// keeps the remainder of the string intact. The resulting dtype is [`DataType::Struct`]. 
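Not part of the patch: the capture-group naming used by `extract_groups` above, reproduced as a standalone snippet. Named groups keep their name; unnamed groups fall back to their 1-based index via `format_pl_smallstr!`. The example pattern is illustrative.

use polars_utils::format_pl_smallstr;
use polars_utils::pl_str::PlSmallStr;

fn main() {
    let reg = regex::Regex::new(r"(?P<year>\d{4})-(\d{2})").unwrap();
    let names: Vec<PlSmallStr> = reg
        .capture_names()
        .enumerate()
        .skip(1) // group 0 is the implicit whole-match group
        .map(|(idx, opt_name)| {
            opt_name
                .map(PlSmallStr::from_str)
                .unwrap_or_else(|| format_pl_smallstr!("{idx}"))
        })
        .collect();
    assert_eq!(
        names.iter().map(|n| n.as_str()).collect::<Vec<_>>(),
        ["year", "2"]
    );
}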
pub fn splitn(self, by: Expr, n: usize) -> Expr { self.0 - .map_many_private(StringFunction::SplitN(n).into(), &[by], false, false) + .map_many_private(StringFunction::SplitN(n).into(), &[by], false, None) } #[cfg(feature = "regex")] @@ -399,7 +401,7 @@ impl StringNameSpace { FunctionExpr::StringExpr(StringFunction::Replace { n: 1, literal }), &[pat, value], false, - true, + Some(Default::default()), ) } @@ -410,7 +412,7 @@ impl StringNameSpace { FunctionExpr::StringExpr(StringFunction::Replace { n, literal }), &[pat, value], false, - true, + Some(Default::default()), ) } @@ -421,7 +423,7 @@ impl StringNameSpace { FunctionExpr::StringExpr(StringFunction::Replace { n: -1, literal }), &[pat, value], false, - true, + Some(Default::default()), ) } @@ -432,7 +434,7 @@ impl StringNameSpace { FunctionExpr::StringExpr(StringFunction::Reverse), &[], false, - false, + None, ) } @@ -442,7 +444,7 @@ impl StringNameSpace { FunctionExpr::StringExpr(StringFunction::StripChars), &[matches], false, - false, + None, ) } @@ -452,7 +454,7 @@ impl StringNameSpace { FunctionExpr::StringExpr(StringFunction::StripCharsStart), &[matches], false, - false, + None, ) } @@ -462,7 +464,7 @@ impl StringNameSpace { FunctionExpr::StringExpr(StringFunction::StripCharsEnd), &[matches], false, - false, + None, ) } @@ -472,7 +474,7 @@ impl StringNameSpace { FunctionExpr::StringExpr(StringFunction::StripPrefix), &[prefix], false, - false, + None, ) } @@ -482,7 +484,7 @@ impl StringNameSpace { FunctionExpr::StringExpr(StringFunction::StripSuffix), &[suffix], false, - false, + None, ) } @@ -512,7 +514,7 @@ impl StringNameSpace { FunctionExpr::StringExpr(StringFunction::ToInteger(strict)), &[base], false, - false, + None, ) } @@ -547,7 +549,7 @@ impl StringNameSpace { FunctionExpr::StringExpr(StringFunction::Slice), &[offset, length], false, - false, + None, ) } @@ -557,7 +559,7 @@ impl StringNameSpace { FunctionExpr::StringExpr(StringFunction::Head), &[n], false, - false, + None, ) } @@ -567,7 +569,7 @@ impl StringNameSpace { FunctionExpr::StringExpr(StringFunction::Tail), &[n], false, - false, + None, ) } @@ -587,7 +589,7 @@ impl StringNameSpace { FunctionExpr::StringExpr(StringFunction::JsonPathMatch), &[pat], false, - false, + None, ) } } diff --git a/crates/polars-plan/src/dsl/struct_.rs b/crates/polars-plan/src/dsl/struct_.rs index b5a1afafa698..adfa9da8bd11 100644 --- a/crates/polars-plan/src/dsl/struct_.rs +++ b/crates/polars-plan/src/dsl/struct_.rs @@ -18,16 +18,15 @@ impl StructNameSpace { /// Retrieve one or multiple of the fields of this [`StructChunked`] as a new Series. /// This expression also expands the `"*"` wildcard column. - pub fn field_by_names>(self, names: &[S]) -> Expr { - self.field_by_names_impl( - names - .iter() - .map(|name| ColumnName::from(name.as_ref())) - .collect(), - ) + pub fn field_by_names(self, names: I) -> Expr + where + I: IntoIterator, + S: Into, + { + self.field_by_names_impl(names.into_iter().map(|x| x.into()).collect()) } - fn field_by_names_impl(self, names: Arc<[ColumnName]>) -> Expr { + fn field_by_names_impl(self, names: Arc<[PlSmallStr]>) -> Expr { self.0 .map_private(FunctionExpr::StructExpr(StructFunction::MultipleFields( names, @@ -42,11 +41,11 @@ impl StructNameSpace { /// This expression also supports wildcard "*" and regex expansion. 
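Not part of the patch: the struct namespace now accepts any iterator of string-like items for field selection and renaming, collecting everything into `Arc<[PlSmallStr]>`. A minimal caller sketch, with the `dtype-struct` feature and prelude imports assumed:

use polars_plan::prelude::*;
use polars_utils::pl_str::PlSmallStr;

fn pick_fields() -> Expr {
    // Arrays, vecs, or any other iterator of `Into<PlSmallStr>` items work.
    col(PlSmallStr::from_static("location"))
        .struct_()
        .field_by_names(["lat", "lon"])
}

fn rename_location_fields() -> Expr {
    col(PlSmallStr::from_static("location"))
        .struct_()
        .rename_fields(["latitude", "longitude"])
}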
pub fn field_by_name(self, name: &str) -> Expr { if name == "*" || is_regex_projection(name) { - return self.field_by_names(&[name]); + return self.field_by_names([name]); } self.0 .map_private(FunctionExpr::StructExpr(StructFunction::FieldByName( - ColumnName::from(name), + name.into(), ))) .with_function_options(|mut options| { options.flags |= FunctionFlags::ALLOW_RENAME; @@ -55,10 +54,18 @@ impl StructNameSpace { } /// Rename the fields of the [`StructChunked`]. - pub fn rename_fields(self, names: Vec) -> Expr { + pub fn rename_fields(self, names: I) -> Expr + where + I: IntoIterator, + S: Into, + { + self._rename_fields_impl(names.into_iter().map(|x| x.into()).collect()) + } + + pub fn _rename_fields_impl(self, names: Arc<[PlSmallStr]>) -> Expr { self.0 .map_private(FunctionExpr::StructExpr(StructFunction::RenameFields( - Arc::from(names), + names, ))) } diff --git a/crates/polars-plan/src/dsl/udf.rs b/crates/polars-plan/src/dsl/udf.rs index 35f59bc78df1..fe01cab03ea2 100644 --- a/crates/polars-plan/src/dsl/udf.rs +++ b/crates/polars-plan/src/dsl/udf.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use arrow::legacy::error::{polars_bail, PolarsResult}; use polars_core::prelude::Field; use polars_core::schema::Schema; +use polars_utils::pl_str::PlSmallStr; use super::{Expr, GetOutput, SeriesUdf, SpecialEq}; use crate::prelude::{Context, FunctionOptions}; @@ -11,7 +12,7 @@ use crate::prelude::{Context, FunctionOptions}; #[derive(Clone)] pub struct UserDefinedFunction { /// name - pub name: String, + pub name: PlSmallStr, /// The function signature. pub input_fields: Vec, /// The function output type. @@ -36,13 +37,13 @@ impl std::fmt::Debug for UserDefinedFunction { impl UserDefinedFunction { /// Create a new UserDefinedFunction pub fn new( - name: &str, + name: PlSmallStr, input_fields: Vec, return_type: GetOutput, fun: impl SeriesUdf + 'static, ) -> Self { Self { - name: name.to_owned(), + name, input_fields, return_type, fun: SpecialEq::new(Arc::new(fun)), diff --git a/crates/polars-plan/src/frame/opt_state.rs b/crates/polars-plan/src/frame/opt_state.rs index 9a3060cc361f..934f42e6109f 100644 --- a/crates/polars-plan/src/frame/opt_state.rs +++ b/crates/polars-plan/src/frame/opt_state.rs @@ -3,7 +3,7 @@ use bitflags::bitflags; bitflags! { #[derive(Copy, Clone, Debug)] /// Allowed optimizations. - pub struct OptState: u32 { + pub struct OptFlags: u32 { /// Only read columns that are used later in the query. const PROJECTION_PUSHDOWN = 1; /// Apply predicates/filters as early as possible. @@ -18,11 +18,9 @@ bitflags! { const FILE_CACHING = 1 << 6; /// Pushdown slices/limits. const SLICE_PUSHDOWN = 1 << 7; - #[cfg(feature = "cse")] /// Run common-subplan-elimination. This elides duplicate plans and caches their /// outputs. const COMM_SUBPLAN_ELIM = 1 << 8; - #[cfg(feature = "cse")] /// Run common-subexpression-elimination. This elides duplicate expressions and caches their /// outputs. const COMM_SUBEXPR_ELIM = 1 << 9; @@ -38,7 +36,13 @@ bitflags! 
{ } } -impl Default for OptState { +impl OptFlags { + pub fn schema_only() -> Self { + Self::TYPE_COERCION + } +} + +impl Default for OptFlags { fn default() -> Self { Self::from_bits_truncate(u32::MAX) & !Self::NEW_STREAMING & !Self::STREAMING & !Self::EAGER // will be toggled by a scan operation such as csv scan or parquet scan @@ -47,4 +51,4 @@ impl Default for OptState { } /// AllowedOptimizations -pub type AllowedOptimizations = OptState; +pub type AllowedOptimizations = OptFlags; diff --git a/crates/polars-plan/src/plans/aexpr/hash.rs b/crates/polars-plan/src/plans/aexpr/hash.rs index ebcea17fb978..6f8c97dc2a36 100644 --- a/crates/polars-plan/src/plans/aexpr/hash.rs +++ b/crates/polars-plan/src/plans/aexpr/hash.rs @@ -13,7 +13,6 @@ impl Hash for AExpr { match self { AExpr::Column(name) => name.hash(state), AExpr::Alias(_, name) => name.hash(state), - AExpr::Nth(v) => v.hash(state), AExpr::Literal(lv) => lv.hash(state), AExpr::Function { options, function, .. diff --git a/crates/polars-plan/src/plans/aexpr/mod.rs b/crates/polars-plan/src/plans/aexpr/mod.rs index a131dd8d087d..70c6335bcd1e 100644 --- a/crates/polars-plan/src/plans/aexpr/mod.rs +++ b/crates/polars-plan/src/plans/aexpr/mod.rs @@ -11,6 +11,8 @@ use polars_core::chunked_array::cast::CastOptions; use polars_core::prelude::*; use polars_core::utils::{get_time_units, try_get_supertype}; use polars_utils::arena::{Arena, Node}; +#[cfg(feature = "ir_serde")] +use serde::{Deserialize, Serialize}; use strum_macros::IntoStaticStr; pub use utils::*; @@ -19,6 +21,7 @@ use crate::plans::Context; use crate::prelude::*; #[derive(Clone, Debug, IntoStaticStr)] +#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] pub enum IRAggExpr { Min { input: Node, @@ -125,10 +128,11 @@ impl From for GroupByMethod { /// IR expression node that is allocated in an [`Arena`][polars_utils::arena::Arena]. #[derive(Clone, Debug, Default)] +#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] pub enum AExpr { Explode(Node), - Alias(Node, ColumnName), - Column(ColumnName), + Alias(Node, PlSmallStr), + Column(PlSmallStr), Literal(LiteralValue), BinaryExpr { left: Node, @@ -137,7 +141,7 @@ pub enum AExpr { }, Cast { expr: Node, - data_type: DataType, + dtype: DataType, options: CastOptions, }, Sort { @@ -186,21 +190,19 @@ pub enum AExpr { order_by: Option<(Node, SortOptions)>, options: WindowType, }, - #[default] - Wildcard, Slice { input: Node, offset: Node, length: Node, }, + #[default] Len, - Nth(i64), } impl AExpr { #[cfg(feature = "cse")] - pub(crate) fn col(name: &str) -> Self { - AExpr::Column(ColumnName::from(name)) + pub(crate) fn col(name: PlSmallStr) -> Self { + AExpr::Column(name) } /// Any expression that is sensitive to the number of elements in a group /// - Aggregations @@ -220,7 +222,6 @@ impl AExpr { | Len | Slice { .. } | Gather { .. } - | Nth(_) => true, Alias(_, _) | Explode(_) @@ -230,7 +231,6 @@ impl AExpr { // to determine if the whole expr. is group sensitive | BinaryExpr { .. } | Ternary { .. } - | Wildcard | Cast { .. } | Filter { .. 
} => false, } @@ -244,7 +244,7 @@ impl AExpr { arena: &Arena, ) -> PolarsResult { self.to_field(schema, ctxt, arena) - .map(|f| f.data_type().clone()) + .map(|f| f.dtype().clone()) } /// Push nodes at this level to a pre-allocated stack @@ -252,7 +252,7 @@ impl AExpr { use AExpr::*; match self { - Nth(_) | Column(_) | Literal(_) | Wildcard | Len => {}, + Column(_) | Literal(_) | Len => {}, Alias(e, _) => container.push_node(*e), BinaryExpr { left, op: _, right } => { // reverse order so that left is popped first @@ -334,7 +334,7 @@ impl AExpr { pub(crate) fn replace_inputs(mut self, inputs: &[Node]) -> Self { use AExpr::*; let input = match &mut self { - Column(_) | Literal(_) | Wildcard | Len | Nth(_) => return self, + Column(_) | Literal(_) | Len => return self, Alias(input, _) => input, Cast { expr, .. } => expr, Explode(input) => input, @@ -424,10 +424,7 @@ impl AExpr { } pub(crate) fn is_leaf(&self) -> bool { - matches!( - self, - AExpr::Column(_) | AExpr::Literal(_) | AExpr::Len | AExpr::Nth(_) - ) + matches!(self, AExpr::Column(_) | AExpr::Literal(_) | AExpr::Len) } } diff --git a/crates/polars-plan/src/plans/aexpr/schema.rs b/crates/polars-plan/src/plans/aexpr/schema.rs index 07f2a94f6b34..0145776684f4 100644 --- a/crates/polars-plan/src/plans/aexpr/schema.rs +++ b/crates/polars-plan/src/plans/aexpr/schema.rs @@ -3,10 +3,15 @@ use recursive::recursive; use super::*; fn float_type(field: &mut Field) { - if (field.dtype.is_numeric() || field.dtype == DataType::Boolean) - && field.dtype != DataType::Float32 - { - field.coerce(DataType::Float64) + let should_coerce = match &field.dtype { + DataType::Float32 => false, + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(..) => true, + DataType::Boolean => true, + dt => dt.is_numeric(), + }; + if should_coerce { + field.coerce(DataType::Float64); } } @@ -36,7 +41,7 @@ impl AExpr { let mut field = self.to_field_impl(schema, arena, &mut nested)?; if nested >= 1 { - field.coerce(field.data_type().clone().implode()); + field.coerce(field.dtype().clone().implode()); } Ok(field) } @@ -54,23 +59,28 @@ impl AExpr { match self { Len => { *nested = 0; - Ok(Field::new(LEN, IDX_DTYPE)) + Ok(Field::new(PlSmallStr::from_static(LEN), IDX_DTYPE)) }, - Window { function, .. } => { + Window { + function, options, .. 
+ } => { + if let WindowType::Over(mapping) = options { + *nested += matches!(mapping, WindowMapping::Join) as u8; + } let e = arena.get(*function); e.to_field_impl(schema, arena, nested) }, Explode(expr) => { let field = arena.get(*expr).to_field_impl(schema, arena, nested)?; - if let List(inner) = field.data_type() { - Ok(Field::new(field.name(), *inner.clone())) + if let List(inner) = field.dtype() { + Ok(Field::new(field.name().clone(), *inner.clone())) } else { Ok(field) } }, Alias(expr, name) => Ok(Field::new( - name, + name.clone(), arena.get(*expr).to_field_impl(schema, arena, nested)?.dtype, )), Column(name) => schema @@ -80,7 +90,7 @@ impl AExpr { *nested = 0; Ok(match sv { LiteralValue::Series(s) => s.field().into_owned(), - _ => Field::new(sv.output_name(), sv.get_datatype()), + _ => Field::new(sv.output_name().clone(), sv.get_datatype()), }) }, BinaryExpr { left, right, op } => { @@ -100,9 +110,9 @@ impl AExpr { let out_field; let out_name = { out_field = arena.get(*left).to_field_impl(schema, arena, nested)?; - out_field.name().as_str() + out_field.name() }; - Field::new(out_name, Boolean) + Field::new(out_name.clone(), Boolean) }, Operator::TrueDivide => { return get_truediv_field(*left, *right, arena, schema, nested) @@ -138,7 +148,7 @@ impl AExpr { Sum(expr) => { *nested = nested.saturating_sub(1); let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?; - let dt = match field.data_type() { + let dt = match field.dtype() { Boolean => Some(IDX_DTYPE), UInt8 | Int8 | Int16 | UInt16 => Some(Int64), _ => None, @@ -168,7 +178,7 @@ impl AExpr { }, Implode(expr) => { let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?; - field.coerce(DataType::List(field.data_type().clone().into())); + field.coerce(DataType::List(field.dtype().clone().into())); Ok(field) }, Std(expr, _) => { @@ -209,11 +219,9 @@ impl AExpr { }, } }, - Cast { - expr, data_type, .. - } => { + Cast { expr, dtype, .. } => { let field = arena.get(*expr).to_field_impl(schema, arena, nested)?; - Ok(Field::new(field.name(), data_type.clone())) + Ok(Field::new(field.name().clone(), dtype.clone())) }, Ternary { truthy, falsy, .. } => { let mut nested_truthy = *nested; @@ -231,10 +239,10 @@ impl AExpr { .get(*falsy) .to_field_impl(schema, arena, &mut nested_falsy)?; - let st = if let DataType::Null = *truthy.data_type() { - falsy.data_type().clone() + let st = if let DataType::Null = *truthy.dtype() { + falsy.dtype().clone() } else { - try_get_supertype(truthy.data_type(), falsy.data_type())? + try_get_supertype(truthy.dtype(), falsy.dtype())? }; *nested = std::cmp::max(nested_truthy, nested_falsy); @@ -269,12 +277,6 @@ impl AExpr { function.get_field(schema, Context::Default, &fields) }, Slice { input, .. 
} => arena.get(*input).to_field_impl(schema, arena, nested), - Wildcard => { - polars_bail!(ComputeError: "wildcard column selection not supported at this point") - }, - Nth(n) => { - polars_bail!(ComputeError: "nth column selection not supported at this point (n={})", n) - }, } } } @@ -293,7 +295,7 @@ fn func_args_to_fields( .get(e.node()) .to_field_impl(schema, arena, nested) .map(|mut field| { - field.name = e.output_name().into(); + field.name = e.output_name().clone(); field }) }) @@ -335,7 +337,7 @@ fn get_arithmetic_field( | (Duration(_), Date) | (Date, Duration(_)) | (Duration(_), Time) - | (Time, Duration(_)) => try_get_supertype(left_field.data_type(), &right_type)?, + | (Time, Duration(_)) => try_get_supertype(left_field.dtype(), &right_type)?, (Datetime(tu, _), Date) | (Date, Datetime(tu, _)) => Duration(*tu), // T - T != T if T is a datetime / date (Datetime(tul, _), Datetime(tur, _)) => Duration(get_time_units(tul, tur)), @@ -364,7 +366,7 @@ fn get_arithmetic_field( | (Duration(_), Date) | (Date, Duration(_)) | (Duration(_), Time) - | (Time, Duration(_)) => try_get_supertype(left_field.data_type(), &right_type)?, + | (Time, Duration(_)) => try_get_supertype(left_field.dtype(), &right_type)?, (_, Datetime(_, _)) | (Datetime(_, _), _) | (_, Date) @@ -450,19 +452,19 @@ fn get_truediv_field( ) -> PolarsResult { let mut left_field = arena.get(left).to_field_impl(schema, arena, nested)?; use DataType::*; - let out_type = match left_field.data_type() { + let out_type = match left_field.dtype() { Float32 => Float32, dt if dt.is_numeric() => Float64, #[cfg(feature = "dtype-duration")] Duration(_) => match arena .get(right) .to_field_impl(schema, arena, nested)? - .data_type() + .dtype() { Duration(_) => Float64, dt if dt.is_numeric() => return Ok(left_field), dt => { - polars_bail!(InvalidOperation: "true division of {} with {} is not allowed", left_field.data_type(), dt) + polars_bail!(InvalidOperation: "true division of {} with {} is not allowed", left_field.dtype(), dt) }, }, #[cfg(feature = "dtype-datetime")] diff --git a/crates/polars-plan/src/plans/anonymous_scan.rs b/crates/polars-plan/src/plans/anonymous_scan.rs index d426b12f9af4..f4a641152091 100644 --- a/crates/polars-plan/src/plans/anonymous_scan.rs +++ b/crates/polars-plan/src/plans/anonymous_scan.rs @@ -8,7 +8,7 @@ use crate::dsl::Expr; pub struct AnonymousScanArgs { pub n_rows: Option, - pub with_columns: Option>, + pub with_columns: Option>, pub schema: SchemaRef, pub output_schema: Option, pub predicate: Option, diff --git a/crates/polars-plan/src/plans/builder_dsl.rs b/crates/polars-plan/src/plans/builder_dsl.rs index b9499111fb04..7efa55417509 100644 --- a/crates/polars-plan/src/plans/builder_dsl.rs +++ b/crates/polars-plan/src/plans/builder_dsl.rs @@ -1,5 +1,3 @@ -#[cfg(any(feature = "csv", feature = "ipc", feature = "parquet"))] -use std::path::PathBuf; use std::sync::{Arc, Mutex, RwLock}; use polars_core::prelude::*; @@ -60,7 +58,10 @@ impl DslBuilder { }; Ok(DslPlan::Scan { - paths: Arc::new(Mutex::new((Arc::new(vec![]), true))), + sources: Arc::new(Mutex::new(DslScanSources { + sources: ScanSources::Buffers(Arc::default()), + is_expanded: true, + })), file_info: Arc::new(RwLock::new(Some(file_info))), hive_parts: None, predicate: None, @@ -79,7 +80,7 @@ impl DslBuilder { #[cfg(feature = "parquet")] #[allow(clippy::too_many_arguments)] pub fn scan_parquet( - paths: Arc>, + sources: DslScanSources, n_rows: Option, cache: bool, parallel: polars_io::parquet::read::ParallelStrategy, @@ -90,10 +91,8 @@ impl 
DslBuilder { use_statistics: bool, hive_options: HiveOptions, glob: bool, - include_file_paths: Option>, + include_file_paths: Option, ) -> PolarsResult { - let paths = init_paths(paths); - let options = FileScanOptions { with_columns: None, cache, @@ -106,7 +105,7 @@ impl DslBuilder { include_file_paths, }; Ok(DslPlan::Scan { - paths, + sources: Arc::new(Mutex::new(sources)), file_info: Arc::new(RwLock::new(None)), hive_parts: None, predicate: None, @@ -127,7 +126,7 @@ impl DslBuilder { #[cfg(feature = "ipc")] #[allow(clippy::too_many_arguments)] pub fn scan_ipc( - paths: Arc>, + sources: DslScanSources, options: IpcScanOptions, n_rows: Option, cache: bool, @@ -135,12 +134,10 @@ impl DslBuilder { rechunk: bool, cloud_options: Option, hive_options: HiveOptions, - include_file_paths: Option>, + include_file_paths: Option, ) -> PolarsResult { - let paths = init_paths(paths); - Ok(DslPlan::Scan { - paths, + sources: Arc::new(Mutex::new(sources)), file_info: Arc::new(RwLock::new(None)), hive_parts: None, file_options: FileScanOptions { @@ -167,15 +164,13 @@ impl DslBuilder { #[allow(clippy::too_many_arguments)] #[cfg(feature = "csv")] pub fn scan_csv( - paths: Arc>, + sources: DslScanSources, read_options: CsvReadOptions, cache: bool, cloud_options: Option, glob: bool, - include_file_paths: Option>, + include_file_paths: Option, ) -> PolarsResult { - let paths = init_paths(paths); - // This gets partially moved by FileScanOptions let read_options_clone = read_options.clone(); @@ -195,7 +190,7 @@ impl DslBuilder { include_file_paths, }; Ok(DslPlan::Scan { - paths, + sources: Arc::new(Mutex::new(sources)), file_info: Arc::new(RwLock::new(None)), hive_parts: None, file_options: options, @@ -346,15 +341,19 @@ impl DslBuilder { .into() } - pub fn explode(self, columns: Vec) -> Self { + pub fn explode(self, columns: Vec, allow_empty: bool) -> Self { DslPlan::MapFunction { input: Arc::new(self.0), - function: DslFunction::Explode { columns }, + function: DslFunction::Explode { + columns, + allow_empty, + }, } .into() } - pub fn unpivot(self, args: UnpivotArgs) -> Self { + #[cfg(feature = "pivot")] + pub fn unpivot(self, args: UnpivotArgsDSL) -> Self { DslPlan::MapFunction { input: Arc::new(self.0), function: DslFunction::Unpivot { args }, @@ -362,18 +361,15 @@ impl DslBuilder { .into() } - pub fn row_index(self, name: &str, offset: Option) -> Self { + pub fn row_index(self, name: PlSmallStr, offset: Option) -> Self { DslPlan::MapFunction { input: Arc::new(self.0), - function: DslFunction::RowIndex { - name: ColumnName::from(name), - offset, - }, + function: DslFunction::RowIndex { name, offset }, } .into() } - pub fn distinct(self, options: DistinctOptions) -> Self { + pub fn distinct(self, options: DistinctOptionsDSL) -> Self { DslPlan::Distinct { input: Arc::new(self.0), options, @@ -402,6 +398,7 @@ impl DslBuilder { input_right: Arc::new(other), left_on, right_on, + predicates: Default::default(), options, } .into() @@ -424,12 +421,12 @@ impl DslBuilder { ) -> Self { DslPlan::MapFunction { input: Arc::new(self.0), - function: DslFunction::FunctionNode(FunctionNode::OpaquePython { + function: DslFunction::OpaquePython(OpaquePythonUdf { function, schema, - predicate_pd: optimizations.contains(OptState::PREDICATE_PUSHDOWN), - projection_pd: optimizations.contains(OptState::PROJECTION_PUSHDOWN), - streamable: optimizations.contains(OptState::STREAMING), + predicate_pd: optimizations.contains(OptFlags::PREDICATE_PUSHDOWN), + projection_pd: optimizations.contains(OptFlags::PROJECTION_PUSHDOWN), + 
streamable: optimizations.contains(OptFlags::STREAMING), validate_output, }), } @@ -441,7 +438,7 @@ impl DslBuilder { function: F, optimizations: AllowedOptimizations, schema: Option>, - name: &'static str, + name: PlSmallStr, ) -> Self where F: DataFrameUdf + 'static, @@ -450,21 +447,15 @@ impl DslBuilder { DslPlan::MapFunction { input: Arc::new(self.0), - function: DslFunction::FunctionNode(FunctionNode::Opaque { + function: DslFunction::FunctionIR(FunctionIR::Opaque { function, schema, - predicate_pd: optimizations.contains(OptState::PREDICATE_PUSHDOWN), - projection_pd: optimizations.contains(OptState::PROJECTION_PUSHDOWN), - streamable: optimizations.contains(OptState::STREAMING), + predicate_pd: optimizations.contains(OptFlags::PREDICATE_PUSHDOWN), + projection_pd: optimizations.contains(OptFlags::PROJECTION_PUSHDOWN), + streamable: optimizations.contains(OptFlags::STREAMING), fmt_str: name, }), } .into() } } - -/// Initialize paths as non-expanded. -#[cfg(any(feature = "csv", feature = "ipc", feature = "parquet"))] -fn init_paths(paths: Arc>) -> Arc>, bool)>> { - Arc::new(Mutex::new((paths, false))) -} diff --git a/crates/polars-plan/src/plans/builder_ir.rs b/crates/polars-plan/src/plans/builder_ir.rs index 1bab177f41b1..7eddfdfea5da 100644 --- a/crates/polars-plan/src/plans/builder_ir.rs +++ b/crates/polars-plan/src/plans/builder_ir.rs @@ -68,7 +68,7 @@ impl<'a> IRBuilder<'a> { let names = nodes .into_iter() .map(|node| match self.expr_arena.get(node.into()) { - AExpr::Column(name) => name.as_ref(), + AExpr::Column(name) => name, _ => unreachable!(), }); // This is a duplication of `project_simple` because we already borrow self.expr_arena :/ @@ -81,7 +81,7 @@ impl<'a> IRBuilder<'a> { .map(|name| { let dtype = input_schema.try_get(name)?; count += 1; - Ok(Field::new(name, dtype.clone())) + Ok(Field::new(name.clone(), dtype.clone())) }) .collect::>()?; @@ -96,10 +96,11 @@ impl<'a> IRBuilder<'a> { } } - pub(crate) fn project_simple<'c, I>(self, names: I) -> PolarsResult + pub(crate) fn project_simple(self, names: I) -> PolarsResult where - I: IntoIterator, + I: IntoIterator, I::IntoIter: ExactSizeIterator, + S: Into, { let names = names.into_iter(); // if len == 0, no projection has to be done. This is a select all operation. 
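Not part of the patch: the renamed `OptFlags` (still aliased as `AllowedOptimizations`) is queried with `contains`, exactly as the builder methods above do. A small sketch, assuming the type is reachable from the prelude:

use polars_plan::prelude::*;

fn main() {
    // Defaults enable everything except the streaming/eager toggles.
    let mut flags = OptFlags::default();
    flags &= !OptFlags::PREDICATE_PUSHDOWN;
    assert!(flags.contains(OptFlags::PROJECTION_PUSHDOWN));
    assert!(!flags.contains(OptFlags::PREDICATE_PUSHDOWN));

    // Schema-only resolution keeps just type coercion enabled.
    assert!(OptFlags::schema_only().contains(OptFlags::TYPE_COERCION));
    assert!(!OptFlags::schema_only().contains(OptFlags::PROJECTION_PUSHDOWN));
}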
@@ -110,7 +111,8 @@ impl<'a> IRBuilder<'a> { let mut count = 0; let schema = names .map(|name| { - let dtype = input_schema.try_get(name)?; + let name: PlSmallStr = name.into(); + let dtype = input_schema.try_get(name.as_str())?; count += 1; Ok(Field::new(name, dtype.clone())) }) @@ -180,11 +182,8 @@ impl<'a> IRBuilder<'a> { .to_field(&schema, Context::Default, self.expr_arena) .unwrap(); - expr_irs.push(ExprIR::new( - node, - OutputName::ColumnLhs(ColumnName::from(field.name.as_ref())), - )); - new_schema.with_column(field.name().clone(), field.data_type().clone()); + expr_irs.push(ExprIR::new(node, OutputName::ColumnLhs(field.name.clone()))); + new_schema.with_column(field.name().clone(), field.dtype().clone()); } let lp = IR::HStack { @@ -197,10 +196,10 @@ impl<'a> IRBuilder<'a> { } // call this if the schema needs to be updated - pub(crate) fn explode(self, columns: Arc<[Arc]>) -> Self { + pub(crate) fn explode(self, columns: Arc<[PlSmallStr]>) -> Self { let lp = IR::MapFunction { input: self.root, - function: FunctionNode::Explode { + function: FunctionIR::Explode { columns, schema: Default::default(), }, @@ -297,10 +296,11 @@ impl<'a> IRBuilder<'a> { self.add_alp(lp) } - pub fn unpivot(self, args: Arc) -> Self { + #[cfg(feature = "pivot")] + pub fn unpivot(self, args: Arc) -> Self { let lp = IR::MapFunction { input: self.root, - function: FunctionNode::Unpivot { + function: FunctionIR::Unpivot { args, schema: Default::default(), }, @@ -308,10 +308,10 @@ impl<'a> IRBuilder<'a> { self.add_alp(lp) } - pub fn row_index(self, name: Arc, offset: Option) -> Self { + pub fn row_index(self, name: PlSmallStr, offset: Option) -> Self { let lp = IR::MapFunction { input: self.root, - function: FunctionNode::RowIndex { + function: FunctionIR::RowIndex { name, offset, schema: Default::default(), diff --git a/crates/polars-plan/src/plans/conversion/convert_utils.rs b/crates/polars-plan/src/plans/conversion/convert_utils.rs index 0267ef1f60a9..51e6940483ba 100644 --- a/crates/polars-plan/src/plans/conversion/convert_utils.rs +++ b/crates/polars-plan/src/plans/conversion/convert_utils.rs @@ -18,10 +18,10 @@ pub(super) fn convert_st_union( let mut exprs = vec![]; let input_schema = lp_arena.get(*input).schema(lp_arena); - let to_cast = input_schema.iter().zip(schema.iter_dtypes()).flat_map( + let to_cast = input_schema.iter().zip(schema.iter_values()).flat_map( |((left_name, left_type), st)| { if left_type != st { - Some(col(left_name.as_ref()).cast(st.clone())) + Some(col(left_name.clone()).cast(st.clone())) } else { None } @@ -30,7 +30,7 @@ pub(super) fn convert_st_union( exprs.extend(to_cast); if !exprs.is_empty() { - let expr = to_expr_irs(exprs, expr_arena); + let expr = to_expr_irs(exprs, expr_arena)?; let lp = IRBuilder::new(*input, expr_arena, lp_arena) .with_columns(expr, Default::default()) .build(); @@ -54,7 +54,7 @@ pub(super) fn convert_diagonal_concat( mut inputs: Vec, lp_arena: &mut Arena, expr_arena: &mut Arena, -) -> Vec { +) -> PolarsResult> { let schemas = nodes_to_schemas(&inputs, lp_arena); let upper_bound_width = schemas.iter().map(|sch| sch.len()).sum(); @@ -69,7 +69,7 @@ pub(super) fn convert_diagonal_concat( }); } if total_schema.is_empty() { - return inputs; + return Ok(inputs); } let mut has_empty = false; @@ -84,10 +84,10 @@ pub(super) fn convert_diagonal_concat( for (name, dtype) in total_schema.iter() { // If a name from Total Schema is not present - append if lf_schema.get_field(name).is_none() { - columns_to_add.push(NULL.lit().cast(dtype.clone()).alias(name)) + 
columns_to_add.push(NULL.lit().cast(dtype.clone()).alias(name.clone())) } } - let expr = to_expr_irs(columns_to_add, expr_arena); + let expr = to_expr_irs(columns_to_add, expr_arena)?; *node = IRBuilder::new(*node, expr_arena, lp_arena) // Add the missing columns .with_columns(expr, Default::default()) @@ -98,13 +98,13 @@ pub(super) fn convert_diagonal_concat( } if has_empty { - inputs + Ok(inputs .into_iter() .zip(schemas) .filter_map(|(input, schema)| if schema.is_empty() { None } else { Some(input) }) - .collect() + .collect()) } else { - inputs + Ok(inputs) } } diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs index 3036fd6c3c49..5ab23be19a14 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -20,10 +20,11 @@ fn expand_expressions( exprs: Vec, lp_arena: &Arena, expr_arena: &mut Arena, + opt_flags: &mut OptFlags, ) -> PolarsResult> { let schema = lp_arena.get(input).schema(lp_arena); - let exprs = rewrite_projections(exprs, &schema, &[])?; - Ok(to_expr_irs(exprs, expr_arena)) + let exprs = rewrite_projections(exprs, &schema, &[], opt_flags)?; + to_expr_irs(exprs, expr_arena) } fn empty_df() -> IR { @@ -51,55 +52,69 @@ macro_rules! failed_here { format!("'{}' failed", stringify!($($t)*)).into() } } +pub(super) use {failed_here, failed_input, failed_input_args}; pub fn to_alp( lp: DslPlan, expr_arena: &mut Arena, lp_arena: &mut Arena, - simplify_expr: bool, - type_coercion: bool, + // Only `SIMPLIFY_EXPR` and `TYPE_COERCION` are respected. + opt_flags: &mut OptFlags, ) -> PolarsResult { - let mut convert = ConversionOptimizer::new(simplify_expr, type_coercion); - to_alp_impl(lp, expr_arena, lp_arena, &mut convert) + let conversion_optimizer = ConversionOptimizer::new( + opt_flags.contains(OptFlags::SIMPLIFY_EXPR), + opt_flags.contains(OptFlags::TYPE_COERCION), + ); + + let mut ctxt = DslConversionContext { + expr_arena, + lp_arena, + conversion_optimizer, + opt_flags, + }; + + to_alp_impl(lp, &mut ctxt) +} + +pub(super) struct DslConversionContext<'a> { + pub(super) expr_arena: &'a mut Arena, + pub(super) lp_arena: &'a mut Arena, + pub(super) conversion_optimizer: ConversionOptimizer, + pub(super) opt_flags: &'a mut OptFlags, +} + +pub(super) fn run_conversion( + lp: IR, + ctxt: &mut DslConversionContext, + name: &str, +) -> PolarsResult { + let lp_node = ctxt.lp_arena.add(lp); + ctxt.conversion_optimizer + .coerce_types(ctxt.expr_arena, ctxt.lp_arena, lp_node) + .map_err(|e| e.context(format!("'{name}' failed").into()))?; + + Ok(lp_node) } /// converts LogicalPlan to IR /// it adds expressions & lps to the respective arenas as it traverses the plan /// finally it returns the top node of the logical plan #[recursive] -pub fn to_alp_impl( - lp: DslPlan, - expr_arena: &mut Arena, - lp_arena: &mut Arena, - convert: &mut ConversionOptimizer, -) -> PolarsResult { +pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult { let owned = Arc::unwrap_or_clone; - fn run_conversion( - lp: IR, - lp_arena: &mut Arena, - expr_arena: &mut Arena, - convert: &mut ConversionOptimizer, - name: &str, - ) -> PolarsResult { - let lp_node = lp_arena.add(lp); - convert - .coerce_types(expr_arena, lp_arena, lp_node) - .map_err(|e| e.context(format!("'{name}' failed").into()))?; - - Ok(lp_node) - } - let v = match lp { DslPlan::Scan { - paths, + sources, file_info, hive_parts, predicate, mut file_options, mut scan_type, } => { - let paths = 
expand_scan_paths(paths, &mut scan_type, &mut file_options)?; + let mut sources_lock = sources.lock().unwrap(); + sources_lock.expand_paths(&mut scan_type, &mut file_options)?; + let sources = sources_lock.sources.clone(); let file_info_read = file_info.read().unwrap(); @@ -125,9 +140,12 @@ pub fn to_alp_impl( metadata, .. } => { - let (file_info, md) = - scans::parquet_file_info(&paths, &file_options, cloud_options.as_ref()) - .map_err(|e| e.context(failed_here!(parquet scan)))?; + let (file_info, md) = scans::parquet_file_info( + &sources, + &file_options, + cloud_options.as_ref(), + ) + .map_err(|e| e.context(failed_here!(parquet scan)))?; *metadata = md; file_info }, @@ -138,7 +156,7 @@ pub fn to_alp_impl( .. } => { let (file_info, md) = - scans::ipc_file_info(&paths, &file_options, cloud_options.as_ref()) + scans::ipc_file_info(&sources, &file_options, cloud_options.as_ref()) .map_err(|e| e.context(failed_here!(ipc scan)))?; *metadata = Some(md); file_info @@ -147,16 +165,19 @@ pub fn to_alp_impl( FileScan::Csv { options, cloud_options, - } => { - scans::csv_file_info(&paths, &file_options, options, cloud_options.as_ref()) - .map_err(|e| e.context(failed_here!(csv scan)))? - }, + } => scans::csv_file_info( + &sources, + &file_options, + options, + cloud_options.as_ref(), + ) + .map_err(|e| e.context(failed_here!(csv scan)))?, #[cfg(feature = "json")] FileScan::NDJson { options, cloud_options, } => scans::ndjson_file_info( - &paths, + &sources, &file_options, options, cloud_options.as_ref(), @@ -172,16 +193,20 @@ pub fn to_alp_impl( } else if file_options.hive_options.enabled.unwrap_or(false) && resolved_file_info.reader_schema.is_some() { + let paths = sources + .as_paths() + .ok_or_else(|| polars_err!(nyi = "Hive-partitioning of in-memory buffers"))?; + #[allow(unused_assignments)] let mut owned = None; hive_partitions_from_paths( - paths.as_ref(), + paths, file_options.hive_options.hive_start_idx, file_options.hive_options.schema.clone(), match resolved_file_info.reader_schema.as_ref().unwrap() { Either::Left(v) => { - owned = Some(Schema::from(v)); + owned = Some(Schema::from_arrow_schema(v.as_ref())); owned.as_ref().unwrap() }, Either::Right(v) => v.as_ref(), @@ -225,7 +250,7 @@ pub fn to_alp_impl( schema.insert_at_index( schema.len(), - file_path_col.as_ref().into(), + file_path_col.clone(), DataType::String, )?; } @@ -245,16 +270,18 @@ pub fn to_alp_impl( if let Some(row_index) = &file_options.row_index { let schema = Arc::make_mut(&mut resolved_file_info.schema); *schema = schema - .new_inserting_at_index(0, row_index.name.as_ref().into(), IDX_DTYPE) + .new_inserting_at_index(0, row_index.name.clone(), IDX_DTYPE) .unwrap(); } IR::Scan { - paths, + sources, file_info: resolved_file_info, hive_parts, output_schema: None, - predicate: predicate.map(|expr| to_expr_ir(expr, expr_arena)), + predicate: predicate + .map(|expr| to_expr_ir(expr, ctxt.expr_arena)) + .transpose()?, scan_type, file_options, } @@ -264,16 +291,17 @@ pub fn to_alp_impl( DslPlan::Union { inputs, args } => { let mut inputs = inputs .into_iter() - .map(|lp| to_alp_impl(lp, expr_arena, lp_arena, convert)) + .map(|lp| to_alp_impl(lp, ctxt)) .collect::>>() .map_err(|e| e.context(failed_input!(vertical concat)))?; if args.diagonal { - inputs = convert_utils::convert_diagonal_concat(inputs, lp_arena, expr_arena); + inputs = + convert_utils::convert_diagonal_concat(inputs, ctxt.lp_arena, ctxt.expr_arena)?; } if args.to_supertypes { - convert_utils::convert_st_union(&mut inputs, lp_arena, expr_arena) + 
convert_utils::convert_st_union(&mut inputs, ctxt.lp_arena, ctxt.expr_arena) .map_err(|e| e.context(failed_input!(vertical concat)))?; } let options = args.into(); @@ -282,11 +310,11 @@ pub fn to_alp_impl( DslPlan::HConcat { inputs, options } => { let inputs = inputs .into_iter() - .map(|lp| to_alp_impl(lp, expr_arena, lp_arena, convert)) + .map(|lp| to_alp_impl(lp, ctxt)) .collect::>>() .map_err(|e| e.context(failed_input!(horizontal concat)))?; - let schema = convert_utils::h_concat_schema(&inputs, lp_arena)?; + let schema = convert_utils::h_concat_schema(&inputs, ctxt.lp_arena)?; IR::HConcat { inputs, @@ -295,14 +323,14 @@ pub fn to_alp_impl( } }, DslPlan::Filter { input, predicate } => { - let mut input = to_alp_impl(owned(input), expr_arena, lp_arena, convert) - .map_err(|e| e.context(failed_input!(filter)))?; - let predicate = expand_filter(predicate, input, lp_arena) + let mut input = + to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_input!(filter)))?; + let predicate = expand_filter(predicate, input, ctxt.lp_arena, ctxt.opt_flags) .map_err(|e| e.context(failed_here!(filter)))?; - let predicate_ae = to_expr_ir(predicate.clone(), expr_arena); + let predicate_ae = to_expr_ir(predicate.clone(), ctxt.expr_arena)?; - return if is_streamable(predicate_ae.node(), expr_arena, Context::Default) { + return if is_streamable(predicate_ae.node(), ctxt.expr_arena, Context::Default) { // Split expression that are ANDed into multiple Filter nodes as the optimizer can then // push them down independently. Especially if they refer columns from different tables // this will be more performant. @@ -327,24 +355,26 @@ pub fn to_alp_impl( } for predicate in predicates { - let predicate = to_expr_ir(predicate, expr_arena); - convert.push_scratch(predicate.node(), expr_arena); + let predicate = to_expr_ir(predicate, ctxt.expr_arena)?; + ctxt.conversion_optimizer + .push_scratch(predicate.node(), ctxt.expr_arena); let lp = IR::Filter { input, predicate }; - input = run_conversion(lp, lp_arena, expr_arena, convert, "filter")?; + input = run_conversion(lp, ctxt, "filter")?; } Ok(input) } else { - convert.push_scratch(predicate_ae.node(), expr_arena); + ctxt.conversion_optimizer + .push_scratch(predicate_ae.node(), ctxt.expr_arena); let lp = IR::Filter { input, predicate: predicate_ae, }; - run_conversion(lp, lp_arena, expr_arena, convert, "filter") + run_conversion(lp, ctxt, "filter") }; }, DslPlan::Slice { input, offset, len } => { - let input = to_alp_impl(owned(input), expr_arena, lp_arena, convert) - .map_err(|e| e.context(failed_input!(slice)))?; + let input = + to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_input!(slice)))?; IR::Slice { input, offset, len } }, DslPlan::DataFrameScan { @@ -356,26 +386,29 @@ pub fn to_alp_impl( df, schema, output_schema, - filter: selection.map(|expr| to_expr_ir(expr, expr_arena)), + filter: selection + .map(|expr| to_expr_ir(expr, ctxt.expr_arena)) + .transpose()?, }, DslPlan::Select { expr, input, options, } => { - let input = to_alp_impl(owned(input), expr_arena, lp_arena, convert) - .map_err(|e| e.context(failed_input!(select)))?; - let schema = lp_arena.get(input).schema(lp_arena); - let (exprs, schema) = - prepare_projection(expr, &schema).map_err(|e| e.context(failed_here!(select)))?; + let input = + to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_input!(select)))?; + let schema = ctxt.lp_arena.get(input).schema(ctxt.lp_arena); + let (exprs, schema) = prepare_projection(expr, &schema, ctxt.opt_flags) + .map_err(|e| 
e.context(failed_here!(select)))?; if exprs.is_empty() { - lp_arena.replace(input, empty_df()); + ctxt.lp_arena.replace(input, empty_df()); } let schema = Arc::new(schema); - let eirs = to_expr_irs(exprs, expr_arena); - convert.fill_scratch(&eirs, expr_arena); + let eirs = to_expr_irs(exprs, ctxt.expr_arena)?; + ctxt.conversion_optimizer + .fill_scratch(&eirs, ctxt.expr_arena); let lp = IR::Select { expr: eirs, @@ -384,7 +417,7 @@ pub fn to_alp_impl( options, }; - return run_conversion(lp, lp_arena, expr_arena, convert, "select"); + return run_conversion(lp, ctxt, "select"); }, DslPlan::Sort { input, @@ -412,8 +445,8 @@ pub fn to_alp_impl( ComputeError: "the length of `nulls_last` ({}) does not match the length of `by` ({})", n_nulls_last, by_column.len() ); - let input = to_alp_impl(owned(input), expr_arena, lp_arena, convert) - .map_err(|e| e.context(failed_input!(sort)))?; + let input = + to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_input!(sort)))?; let mut expanded_cols = Vec::new(); let mut nulls_last = Vec::new(); @@ -429,8 +462,14 @@ pub fn to_alp_impl( .cycle() .zip(sort_options.descending.iter().cycle()), ) { - let exprs = expand_expressions(input, vec![c], lp_arena, expr_arena) - .map_err(|e| e.context(failed_here!(sort)))?; + let exprs = expand_expressions( + input, + vec![c], + ctxt.lp_arena, + ctxt.expr_arena, + ctxt.opt_flags, + ) + .map_err(|e| e.context(failed_here!(sort)))?; nulls_last.extend(std::iter::repeat(n).take(exprs.len())); descending.extend(std::iter::repeat(d).take(exprs.len())); @@ -439,7 +478,8 @@ pub fn to_alp_impl( sort_options.nulls_last = nulls_last; sort_options.descending = descending; - convert.fill_scratch(&expanded_cols, expr_arena); + ctxt.conversion_optimizer + .fill_scratch(&expanded_cols, ctxt.expr_arena); let by_column = expanded_cols; let lp = IR::Sort { @@ -449,15 +489,15 @@ pub fn to_alp_impl( sort_options, }; - return run_conversion(lp, lp_arena, expr_arena, convert, "sort"); + return run_conversion(lp, ctxt, "sort"); }, DslPlan::Cache { input, id, cache_hits, } => { - let input = to_alp_impl(owned(input), expr_arena, lp_arena, convert) - .map_err(|e| e.context(failed_input!(cache)))?; + let input = + to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_input!(cache)))?; IR::Cache { input, id, @@ -472,12 +512,19 @@ pub fn to_alp_impl( maintain_order, options, } => { - let input = to_alp_impl(owned(input), expr_arena, lp_arena, convert) - .map_err(|e| e.context(failed_input!(group_by)))?; + let input = + to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_input!(group_by)))?; - let (keys, aggs, schema) = - resolve_group_by(input, keys, aggs, &options, lp_arena, expr_arena) - .map_err(|e| e.context(failed_here!(group_by)))?; + let (keys, aggs, schema) = resolve_group_by( + input, + keys, + aggs, + &options, + ctxt.lp_arena, + ctxt.expr_arena, + ctxt.opt_flags, + ) + .map_err(|e| e.context(failed_here!(group_by)))?; let (apply, schema) = if let Some((apply, schema)) = apply { (Some(apply), schema) @@ -485,8 +532,10 @@ pub fn to_alp_impl( (None, schema) }; - convert.fill_scratch(&keys, expr_arena); - convert.fill_scratch(&aggs, expr_arena); + ctxt.conversion_optimizer + .fill_scratch(&keys, ctxt.expr_arena); + ctxt.conversion_optimizer + .fill_scratch(&aggs, ctxt.expr_arena); let lp = IR::GroupBy { input, @@ -498,142 +547,113 @@ pub fn to_alp_impl( options, }; - return run_conversion(lp, lp_arena, expr_arena, convert, "group_by"); + return run_conversion(lp, ctxt, "group_by"); }, DslPlan::Join { input_left, 
input_right, left_on, right_on, - mut options, + predicates, + options, } => { - if matches!(options.args.how, JoinType::Cross) { - polars_ensure!(left_on.len() + right_on.len() == 0, InvalidOperation: "a 'cross' join doesn't expect any join keys"); - } else { - let mut turn_off_coalesce = false; - for e in left_on.iter().chain(right_on.iter()) { - if has_expr(e, |e| matches!(e, Expr::Alias(_, _))) { - polars_bail!( - ComputeError: - "'alias' is not allowed in a join key, use 'with_columns' first", - ) - } - // Any expression that is not a simple column expression will turn of coalescing. - turn_off_coalesce |= has_expr(e, |e| !matches!(e, Expr::Column(_))); - } - if turn_off_coalesce { - let options = Arc::make_mut(&mut options); - if matches!(options.args.coalesce, JoinCoalesce::CoalesceColumns) { - polars_warn!("coalescing join requested but not all join keys are column references, turning off key coalescing"); - } - options.args.coalesce = JoinCoalesce::KeepColumns; - } - - options.args.validation.is_valid_join(&options.args.how)?; - - polars_ensure!( - left_on.len() == right_on.len(), - ComputeError: - format!( - "the number of columns given as join key (left: {}, right:{}) should be equal", - left_on.len(), - right_on.len() - ) - ); - } - - let input_left = to_alp_impl(owned(input_left), expr_arena, lp_arena, convert) - .map_err(|e| e.context(failed_input!(join left)))?; - let input_right = to_alp_impl(owned(input_right), expr_arena, lp_arena, convert) - .map_err(|e| e.context(failed_input!(join, right)))?; - - let schema_left = lp_arena.get(input_left).schema(lp_arena); - let schema_right = lp_arena.get(input_right).schema(lp_arena); - - let schema = - det_join_schema(&schema_left, &schema_right, &left_on, &right_on, &options) - .map_err(|e| e.context(failed_here!(join schema resolving)))?; - - let left_on = to_expr_irs_ignore_alias(left_on, expr_arena); - let right_on = to_expr_irs_ignore_alias(right_on, expr_arena); - let mut joined_on = PlHashSet::new(); - for (l, r) in left_on.iter().zip(right_on.iter()) { - polars_ensure!( - joined_on.insert((l.output_name(), r.output_name())), - InvalidOperation: "joining with repeated key names; already joined on {} and {}", - l.output_name(), - r.output_name() - ) - } - drop(joined_on); - - convert.fill_scratch(&left_on, expr_arena); - convert.fill_scratch(&right_on, expr_arena); - - // Every expression must be elementwise so that we are - // guaranteed the keys for a join are all the same length. - let all_elementwise = - |aexprs: &[ExprIR]| all_streamable(aexprs, &*expr_arena, Context::Default); - polars_ensure!( - all_elementwise(&left_on) && all_elementwise(&right_on), - InvalidOperation: "All join key expressions must be elementwise." 
- ); - let lp = IR::Join { + return join::resolve_join( input_left, input_right, - schema, left_on, right_on, + predicates, options, - }; - return run_conversion(lp, lp_arena, expr_arena, convert, "join"); + ctxt, + ) }, DslPlan::HStack { input, exprs, options, } => { - let input = to_alp_impl(owned(input), expr_arena, lp_arena, convert) + let input = to_alp_impl(owned(input), ctxt) .map_err(|e| e.context(failed_input!(with_columns)))?; - let (exprs, schema) = resolve_with_columns(exprs, input, lp_arena, expr_arena) - .map_err(|e| e.context(failed_here!(with_columns)))?; + let (exprs, schema) = + resolve_with_columns(exprs, input, ctxt.lp_arena, ctxt.expr_arena, ctxt.opt_flags) + .map_err(|e| e.context(failed_here!(with_columns)))?; - convert.fill_scratch(&exprs, expr_arena); + ctxt.conversion_optimizer + .fill_scratch(&exprs, ctxt.expr_arena); let lp = IR::HStack { input, exprs, schema, options, }; - return run_conversion(lp, lp_arena, expr_arena, convert, "with_columns"); + return run_conversion(lp, ctxt, "with_columns"); }, DslPlan::Distinct { input, options } => { - let input = to_alp_impl(owned(input), expr_arena, lp_arena, convert) - .map_err(|e| e.context(failed_input!(unique)))?; + let input = + to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_input!(unique)))?; + let input_schema = ctxt.lp_arena.get(input).schema(ctxt.lp_arena); + + let subset = options + .subset + .map(|s| expand_selectors(s, input_schema.as_ref(), &[])) + .transpose()?; + let options = DistinctOptionsIR { + subset, + maintain_order: options.maintain_order, + keep_strategy: options.keep_strategy, + slice: None, + }; + IR::Distinct { input, options } }, DslPlan::MapFunction { input, function } => { - let input = to_alp_impl(owned(input), expr_arena, lp_arena, convert).map_err(|e| { + let input = to_alp_impl(owned(input), ctxt).map_err(|e| { e.context(failed_input_args!(format!("{}", function).to_lowercase())) })?; - let input_schema = lp_arena.get(input).schema(lp_arena); + let input_schema = ctxt.lp_arena.get(input).schema(ctxt.lp_arena); match function { + DslFunction::Explode { + columns, + allow_empty, + } => { + let columns = expand_selectors(columns, &input_schema, &[])?; + validate_columns_in_input(&columns, &input_schema, "explode")?; + polars_ensure!(!columns.is_empty() || allow_empty, InvalidOperation: "no columns provided in explode"); + if columns.is_empty() { + return Ok(input); + } + let function = FunctionIR::Explode { + columns, + schema: Default::default(), + }; + let ir = IR::MapFunction { input, function }; + return Ok(ctxt.lp_arena.add(ir)); + }, DslFunction::FillNan(fill_value) => { let exprs = input_schema .iter() .filter_map(|(name, dtype)| match dtype { - DataType::Float32 | DataType::Float64 => { - Some(col(name).fill_nan(fill_value.clone()).alias(name)) - }, + DataType::Float32 | DataType::Float64 => Some( + col(name.clone()) + .fill_nan(fill_value.clone()) + .alias(name.clone()), + ), _ => None, }) .collect::>(); - let (exprs, schema) = resolve_with_columns(exprs, input, lp_arena, expr_arena) - .map_err(|e| e.context(failed_here!(fill_nan)))?; + let (exprs, schema) = resolve_with_columns( + exprs, + input, + ctxt.lp_arena, + ctxt.expr_arena, + ctxt.opt_flags, + ) + .map_err(|e| e.context(failed_here!(fill_nan)))?; - convert.fill_scratch(&exprs, expr_arena); + ctxt.conversion_optimizer + .fill_scratch(&exprs, ctxt.expr_arena); let lp = IR::HStack { input, @@ -644,7 +664,7 @@ pub fn to_alp_impl( ..Default::default() }, }; - return run_conversion(lp, lp_arena, expr_arena, 
convert, "fill_nan"); + return run_conversion(lp, ctxt, "fill_nan"); }, DslFunction::Drop(DropFunction { to_drop, strict }) => { let to_drop = expand_selectors(to_drop, &input_schema, &[])?; @@ -669,7 +689,7 @@ pub fn to_alp_impl( } if output_schema.is_empty() { - lp_arena.replace(input, empty_df()); + ctxt.lp_arena.replace(input, empty_df()); } IR::SimpleProjection { @@ -681,22 +701,22 @@ pub fn to_alp_impl( let exprs = match sf { StatsFunction::Var { ddof } => stats_helper( |dt| dt.is_numeric() || dt.is_bool(), - |name| col(name).var(ddof), + |name| col(name.clone()).var(ddof), &input_schema, ), StatsFunction::Std { ddof } => stats_helper( |dt| dt.is_numeric() || dt.is_bool(), - |name| col(name).std(ddof), + |name| col(name.clone()).std(ddof), &input_schema, ), StatsFunction::Quantile { quantile, interpol } => stats_helper( |dt| dt.is_numeric(), - |name| col(name).quantile(quantile.clone(), interpol), + |name| col(name.clone()).quantile(quantile.clone(), interpol), &input_schema, ), StatsFunction::Mean => stats_helper( |dt| dt.is_numeric() || dt.is_temporal() || dt == &DataType::Boolean, - |name| col(name).mean(), + |name| col(name.clone()).mean(), &input_schema, ), StatsFunction::Sum => stats_helper( @@ -705,18 +725,22 @@ pub fn to_alp_impl( || dt.is_decimal() || matches!(dt, DataType::Boolean | DataType::Duration(_)) }, - |name| col(name).sum(), + |name| col(name.clone()).sum(), + &input_schema, + ), + StatsFunction::Min => stats_helper( + |dt| dt.is_ord(), + |name| col(name.clone()).min(), + &input_schema, + ), + StatsFunction::Max => stats_helper( + |dt| dt.is_ord(), + |name| col(name.clone()).max(), &input_schema, ), - StatsFunction::Min => { - stats_helper(|dt| dt.is_ord(), |name| col(name).min(), &input_schema) - }, - StatsFunction::Max => { - stats_helper(|dt| dt.is_ord(), |name| col(name).max(), &input_schema) - }, StatsFunction::Median => stats_helper( |dt| dt.is_numeric() || dt.is_temporal() || dt == &DataType::Boolean, - |name| col(name).median(), + |name| col(name.clone()).median(), &input_schema, ), }; @@ -725,9 +749,10 @@ pub fn to_alp_impl( &input_schema, Context::Default, )?); - let eirs = to_expr_irs(exprs, expr_arena); + let eirs = to_expr_irs(exprs, ctxt.expr_arena)?; - convert.fill_scratch(&eirs, expr_arena); + ctxt.conversion_optimizer + .fill_scratch(&eirs, ctxt.expr_arena); let lp = IR::Select { input, @@ -738,26 +763,26 @@ pub fn to_alp_impl( ..Default::default() }, }; - return run_conversion(lp, lp_arena, expr_arena, convert, "stats"); + return run_conversion(lp, ctxt, "stats"); }, _ => { - let function = function.into_function_node(&input_schema)?; + let function = function.into_function_ir(&input_schema)?; IR::MapFunction { input, function } }, } }, DslPlan::ExtContext { input, contexts } => { - let input = to_alp_impl(owned(input), expr_arena, lp_arena, convert) + let input = to_alp_impl(owned(input), ctxt) .map_err(|e| e.context(failed_input!(with_context)))?; let contexts = contexts .into_iter() - .map(|lp| to_alp_impl(lp, expr_arena, lp_arena, convert)) + .map(|lp| to_alp_impl(lp, ctxt)) .collect::>>() .map_err(|e| e.context(failed_here!(with_context)))?; - let mut schema = (**lp_arena.get(input).schema(lp_arena)).clone(); + let mut schema = (**ctxt.lp_arena.get(input).schema(ctxt.lp_arena)).clone(); for input in &contexts { - let other_schema = lp_arena.get(*input).schema(lp_arena); + let other_schema = ctxt.lp_arena.get(*input).schema(ctxt.lp_arena); for fld in other_schema.iter_fields() { if schema.get(fld.name()).is_none() { 
schema.with_column(fld.name, fld.dtype); @@ -772,62 +797,66 @@ pub fn to_alp_impl( } }, DslPlan::Sink { input, payload } => { - let input = to_alp_impl(owned(input), expr_arena, lp_arena, convert) - .map_err(|e| e.context(failed_input!(sink)))?; + let input = + to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_input!(sink)))?; IR::Sink { input, payload } }, DslPlan::IR { node, dsl, version } => { - return if let (true, Some(node)) = (version == lp_arena.version(), node) { - Ok(node) + return if node.is_some() + && version == ctxt.lp_arena.version() + && ctxt.conversion_optimizer.used_arenas.insert(version) + { + Ok(node.unwrap()) } else { - to_alp_impl(owned(dsl), expr_arena, lp_arena, convert) + to_alp_impl(owned(dsl), ctxt) } }, }; - Ok(lp_arena.add(v)) + Ok(ctxt.lp_arena.add(v)) } -/// Expand scan paths if they were not already expanded. -#[allow(unused_variables)] -fn expand_scan_paths( - paths: Arc>, bool)>>, - scan_type: &mut FileScan, - file_options: &mut FileScanOptions, -) -> PolarsResult>> { - #[allow(unused_mut)] - let mut lock = paths.lock().unwrap(); +impl DslScanSources { + /// Expand scan paths if they were not already expanded. + pub fn expand_paths( + &mut self, + scan_type: &mut FileScan, + file_options: &mut FileScanOptions, + ) -> PolarsResult<()> { + if self.is_expanded { + return Ok(()); + } - // Return if paths are already expanded - if lock.1 { - return Ok(lock.0.clone()); - } + let ScanSources::Paths(paths) = &self.sources else { + self.is_expanded = true; + return Ok(()); + }; - { - let paths_expanded = match &scan_type { + let expanded_sources = match &scan_type { #[cfg(feature = "parquet")] FileScan::Parquet { cloud_options, .. } => { - expand_scan_paths_with_hive_update(&lock.0, file_options, cloud_options)? + expand_scan_paths_with_hive_update(paths, file_options, cloud_options)? }, #[cfg(feature = "ipc")] FileScan::Ipc { cloud_options, .. } => { - expand_scan_paths_with_hive_update(&lock.0, file_options, cloud_options)? + expand_scan_paths_with_hive_update(paths, file_options, cloud_options)? }, #[cfg(feature = "csv")] FileScan::Csv { cloud_options, .. } => { - expand_paths(&lock.0, file_options.glob, cloud_options.as_ref())? + expand_paths(paths, file_options.glob, cloud_options.as_ref())? }, #[cfg(feature = "json")] FileScan::NDJson { cloud_options, .. } => { - expand_paths(&lock.0, file_options.glob, cloud_options.as_ref())? + expand_paths(paths, file_options.glob, cloud_options.as_ref())? }, FileScan::Anonymous { .. } => unreachable!(), // Invariant: Anonymous scans are already expanded. 
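Illustrative only: a minimal stand-alone sketch of the expand-once guard that `DslScanSources::expand_paths` uses above. All names here are made up; the real method expands globs (optionally against cloud storage) and updates hive options, while this sketch only shows the `is_expanded` idempotence pattern.

use std::path::PathBuf;

/// Illustrative stand-in for a scan-source list that may still contain globs.
struct Sources {
    paths: Vec<PathBuf>,
    is_expanded: bool,
}

impl Sources {
    /// Expand glob patterns into concrete paths exactly once.
    /// Later calls return immediately because the flag is already set.
    fn expand_paths(&mut self, expand: impl Fn(&PathBuf) -> Vec<PathBuf>) {
        if self.is_expanded {
            return;
        }
        let expanded: Vec<PathBuf> = self.paths.iter().flat_map(|p| expand(p)).collect();
        self.paths = expanded;
        self.is_expanded = true;
    }
}

fn main() {
    let mut sources = Sources {
        paths: vec![PathBuf::from("data/*.parquet")],
        is_expanded: false,
    };
    // A fake expander for the example; real code would hit the filesystem or object store.
    let expand = |_p: &PathBuf| vec![PathBuf::from("data/a.parquet"), PathBuf::from("data/b.parquet")];
    sources.expand_paths(expand);
    sources.expand_paths(expand); // no-op: already expanded
    assert_eq!(sources.paths.len(), 2);
}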
}; #[allow(unreachable_code)] { - *lock = (paths_expanded, true); + self.sources = ScanSources::Paths(expanded_sources); + self.is_expanded = true; - Ok(lock.0.clone()) + Ok(()) } } } @@ -838,7 +867,7 @@ fn expand_scan_paths_with_hive_update( paths: &[PathBuf], file_options: &mut FileScanOptions, cloud_options: &Option, -) -> PolarsResult>> { +) -> PolarsResult> { let hive_enabled = file_options.hive_options.enabled; let (expanded_paths, hive_start_idx) = expand_paths_hive( paths, @@ -855,7 +884,12 @@ fn expand_scan_paths_with_hive_update( Ok(expanded_paths) } -fn expand_filter(predicate: Expr, input: Node, lp_arena: &Arena) -> PolarsResult { +fn expand_filter( + predicate: Expr, + input: Node, + lp_arena: &Arena, + opt_flags: &mut OptFlags, +) -> PolarsResult { let schema = lp_arena.get(input).schema(lp_arena); let predicate = if has_expr(&predicate, |e| match e { Expr::Column(name) => is_regex_projection(name), @@ -868,7 +902,7 @@ fn expand_filter(predicate: Expr, input: Node, lp_arena: &Arena) -> PolarsRe | Expr::Nth(_) => true, _ => false, }) { - let mut rewritten = rewrite_projections(vec![predicate], &schema, &[])?; + let mut rewritten = rewrite_projections(vec![predicate], &schema, &[], opt_flags)?; match rewritten.len() { 1 => { // all good @@ -915,10 +949,11 @@ fn resolve_with_columns( input: Node, lp_arena: &Arena, expr_arena: &mut Arena, + opt_flags: &mut OptFlags, ) -> PolarsResult<(Vec, SchemaRef)> { let schema = lp_arena.get(input).schema(lp_arena); let mut new_schema = (**schema).clone(); - let (exprs, _) = prepare_projection(exprs, &schema)?; + let (exprs, _) = prepare_projection(exprs, &schema, opt_flags)?; let mut output_names = PlHashSet::with_capacity(exprs.len()); let mut arena = Arena::with_capacity(8); @@ -937,11 +972,11 @@ fn resolve_with_columns( ); polars_bail!(ComputeError: msg) } - new_schema.with_column(field.name().clone(), field.data_type().clone()); + new_schema.with_column(field.name().clone(), field.dtype().clone()); arena.clear(); } - let eirs = to_expr_irs(exprs, expr_arena); + let eirs = to_expr_irs(exprs, expr_arena)?; Ok((eirs, Arc::new(new_schema))) } @@ -952,10 +987,11 @@ fn resolve_group_by( _options: &GroupbyOptions, lp_arena: &Arena, expr_arena: &mut Arena, + opt_flags: &mut OptFlags, ) -> PolarsResult<(Vec, Vec, SchemaRef)> { let current_schema = lp_arena.get(input).schema(lp_arena); let current_schema = current_schema.as_ref(); - let mut keys = rewrite_projections(keys, current_schema, &[])?; + let mut keys = rewrite_projections(keys, current_schema, &[], opt_flags)?; // Initialize schema from keys let mut schema = expressions_to_schema(&keys, current_schema, Context::Default)?; @@ -967,16 +1003,16 @@ fn resolve_group_by( #[cfg(feature = "dynamic_group_by")] { if let Some(options) = _options.rolling.as_ref() { - let name = &options.index_column; - let dtype = current_schema.try_get(name)?; - keys.push(col(name)); + let name = options.index_column.clone(); + let dtype = current_schema.try_get(name.as_str())?; + keys.push(col(name.clone())); pop_keys = true; schema.with_column(name.clone(), dtype.clone()); } else if let Some(options) = _options.dynamic.as_ref() { - let name = &options.index_column; - keys.push(col(name)); + let name = options.index_column.clone(); + keys.push(col(name.clone())); pop_keys = true; - let dtype = current_schema.try_get(name)?; + let dtype = current_schema.try_get(name.as_str())?; if options.include_boundaries { schema.with_column("_lower_boundary".into(), dtype.clone()); schema.with_column("_upper_boundary".into(), 
dtype.clone()); @@ -986,7 +1022,7 @@ fn resolve_group_by( } let keys_index_len = schema.len(); - let aggs = rewrite_projections(aggs, current_schema, &keys)?; + let aggs = rewrite_projections(aggs, current_schema, &keys, opt_flags)?; if pop_keys { let _ = keys.pop(); } @@ -1003,15 +1039,15 @@ fn resolve_group_by( polars_ensure!(names.insert(name.clone()), duplicate = name) } } - let aggs = to_expr_irs(aggs, expr_arena); - let keys = keys.convert(|e| to_expr_ir(e.clone(), expr_arena)); + let aggs = to_expr_irs(aggs, expr_arena)?; + let keys = to_expr_irs(keys, expr_arena)?; Ok((keys, aggs, Arc::new(schema))) } fn stats_helper(condition: F, expr: E, schema: &Schema) -> Vec where F: Fn(&DataType) -> bool, - E: Fn(&str) -> Expr, + E: Fn(&PlSmallStr) -> Expr, { schema .iter() @@ -1019,7 +1055,7 @@ where if condition(dt) { expr(name) } else { - lit(NULL).cast(dt.clone()).alias(name) + lit(NULL).cast(dt.clone()).alias(name.clone()) } }) .collect() @@ -1028,7 +1064,7 @@ where pub(crate) fn maybe_init_projection_excluding_hive( reader_schema: &Either, hive_parts: Option<&HivePartitions>, -) -> Option> { +) -> Option> { // Update `with_columns` with a projection so that hive columns aren't loaded from the // file let hive_parts = hive_parts?; @@ -1037,21 +1073,22 @@ pub(crate) fn maybe_init_projection_excluding_hive( let (first_hive_name, _) = hive_schema.get_at_index(0)?; + // TODO: Optimize this let names = match reader_schema { - Either::Left(ref v) => { - let names = v.get_names(); - names.contains(&first_hive_name.as_str()).then_some(names) - }, - Either::Right(ref v) => v.contains(first_hive_name.as_str()).then(|| v.get_names()), + Either::Left(ref v) => v + .contains(first_hive_name.as_str()) + .then(|| v.iter_names_cloned().collect::>()), + Either::Right(ref v) => v + .contains(first_hive_name.as_str()) + .then(|| v.iter_names_cloned().collect()), }; let names = names?; Some( names - .iter() + .into_iter() .filter(|x| !hive_schema.contains(x)) - .map(ToString::to_string) .collect::>(), ) } diff --git a/crates/polars-plan/src/plans/conversion/expr_expansion.rs b/crates/polars-plan/src/plans/conversion/expr_expansion.rs index d069fc38ae85..b17db4c728d0 100644 --- a/crates/polars-plan/src/plans/conversion/expr_expansion.rs +++ b/crates/polars-plan/src/plans/conversion/expr_expansion.rs @@ -6,15 +6,16 @@ use super::*; pub(crate) fn prepare_projection( exprs: Vec, schema: &Schema, + opt_flags: &mut OptFlags, ) -> PolarsResult<(Vec, Schema)> { - let exprs = rewrite_projections(exprs, schema, &[])?; + let exprs = rewrite_projections(exprs, schema, &[], opt_flags)?; let schema = expressions_to_schema(&exprs, schema, Context::Default)?; Ok((exprs, schema)) } /// This replaces the wildcard Expr with a Column Expr. It also removes the Exclude Expr from the /// expression chain. 
-pub(super) fn replace_wildcard_with_column(expr: Expr, column_name: Arc) -> Expr { +pub(super) fn replace_wildcard_with_column(expr: Expr, column_name: &PlSmallStr) -> Expr { expr.map_expr(|e| match e { Expr::Wildcard => Expr::Column(column_name.clone()), Expr::Exclude(input, _) => Arc::unwrap_or_clone(input), @@ -46,7 +47,7 @@ fn rewrite_special_aliases(expr: Expr) -> PolarsResult { Expr::RenameAlias { expr, function } => { let name = get_single_leaf(&expr).unwrap(); let name = function.call(&name)?; - Ok(Expr::Alias(expr, ColumnName::from(name))) + Ok(Expr::Alias(expr, name)) }, _ => { polars_bail!(InvalidOperation: "`keep`, `suffix`, `prefix` should be last expression") @@ -63,13 +64,12 @@ fn rewrite_special_aliases(expr: Expr) -> PolarsResult { fn replace_wildcard( expr: &Expr, result: &mut Vec, - exclude: &PlHashSet>, + exclude: &PlHashSet, schema: &Schema, ) -> PolarsResult<()> { for name in schema.iter_names() { if !exclude.contains(name.as_str()) { - let new_expr = - replace_wildcard_with_column(expr.clone(), ColumnName::from(name.as_str())); + let new_expr = replace_wildcard_with_column(expr.clone(), name); let new_expr = rewrite_special_aliases(new_expr)?; result.push(new_expr) } @@ -87,11 +87,11 @@ fn replace_nth(expr: Expr, schema: &Schema) -> Expr { -1 => "last", _ => "nth", }; - Expr::Column(ColumnName::from(name)) + Expr::Column(PlSmallStr::from_static(name)) }, Some(idx) => { let (name, _dtype) = schema.get_at_index(idx).unwrap(); - Expr::Column(ColumnName::from(&**name)) + Expr::Column(name.clone()) }, } } else { @@ -108,7 +108,7 @@ fn expand_regex( result: &mut Vec, schema: &Schema, pattern: &str, - exclude: &PlHashSet>, + exclude: &PlHashSet, ) -> PolarsResult<()> { let re = regex::Regex::new(pattern).map_err(|e| polars_err!(ComputeError: "invalid regex {}", e))?; @@ -117,9 +117,7 @@ fn expand_regex( let mut new_expr = remove_exclude(expr.clone()); new_expr = new_expr.map_expr(|e| match e { - Expr::Column(pat) if pat.as_ref() == pattern => { - Expr::Column(ColumnName::from(name.as_str())) - }, + Expr::Column(pat) if pat.as_str() == pattern => Expr::Column(name.clone()), e => e, }); @@ -141,7 +139,7 @@ fn replace_regex( expr: &Expr, result: &mut Vec, schema: &Schema, - exclude: &PlHashSet>, + exclude: &PlHashSet, ) -> PolarsResult<()> { let roots = expr_to_leaf_column_names(expr); let mut regex = None; @@ -174,9 +172,9 @@ fn replace_regex( fn expand_columns( expr: &Expr, result: &mut Vec, - names: &[ColumnName], + names: &[PlSmallStr], schema: &Schema, - exclude: &PlHashSet, + exclude: &PlHashSet, ) -> PolarsResult<()> { let mut is_valid = true; for name in names { @@ -215,12 +213,10 @@ fn struct_index_to_field(expr: Expr, schema: &Schema) -> PolarsResult { polars_bail!(InvalidOperation: "expected 'struct' dtype, got {:?}", dtype) }; let index = index.try_negative_to_usize(fields.len())?; - let name = fields[index].name.as_str(); + let name = fields[index].name.clone(); Ok(Expr::Function { input, - function: FunctionExpr::StructExpr(StructFunction::FieldByName( - ColumnName::from(name), - )), + function: FunctionExpr::StructExpr(StructFunction::FieldByName(name)), options, }) } else { @@ -239,7 +235,7 @@ fn struct_index_to_field(expr: Expr, schema: &Schema) -> PolarsResult { /// ()It also removes the Exclude Expr from the expression chain). 
fn replace_dtype_or_index_with_column( expr: Expr, - column_name: &ColumnName, + column_name: &PlSmallStr, replace_dtype: bool, ) -> Expr { expr.map_expr(|e| match e { @@ -254,8 +250,8 @@ fn replace_dtype_or_index_with_column( /// expression chain. pub(super) fn replace_columns_with_column( mut expr: Expr, - names: &[ColumnName], - column_name: &ColumnName, + names: &[PlSmallStr], + column_name: &PlSmallStr, ) -> (Expr, bool) { let mut is_valid = true; expr = expr.map_expr(|e| match e { @@ -294,7 +290,7 @@ fn expand_dtypes( result: &mut Vec, schema: &Schema, dtypes: &[DataType], - exclude: &PlHashSet>, + exclude: &PlHashSet, ) -> PolarsResult<()> { // note: we loop over the schema to guarantee that we return a stable // field-order, irrespective of which dtypes are filtered against @@ -304,8 +300,7 @@ fn expand_dtypes( }) { let name = field.name(); let new_expr = expr.clone(); - let new_expr = - replace_dtype_or_index_with_column(new_expr, &ColumnName::from(name.as_str()), true); + let new_expr = replace_dtype_or_index_with_column(new_expr, name, true); let new_expr = rewrite_special_aliases(new_expr)?; result.push(new_expr) } @@ -315,7 +310,7 @@ fn expand_dtypes( #[cfg(feature = "dtype-struct")] fn replace_struct_multiple_fields_with_field( expr: Expr, - column_name: &ColumnName, + column_name: &PlSmallStr, ) -> PolarsResult { let mut count = 0; let out = expr.map_expr(|e| match e { @@ -356,8 +351,8 @@ fn expand_struct_fields( full_expr: &Expr, result: &mut Vec, schema: &Schema, - names: &[ColumnName], - exclude: &PlHashSet>, + names: &[PlSmallStr], + exclude: &PlHashSet, ) -> PolarsResult<()> { let first_name = names[0].as_ref(); if names.len() == 1 && first_name == "*" || is_regex_projection(first_name) { @@ -365,7 +360,7 @@ fn expand_struct_fields( unreachable!() }; let field = input[0].to_field(schema, Context::Default)?; - let DataType::Struct(fields) = field.data_type() else { + let DataType::Struct(fields) = field.dtype() else { polars_bail!(InvalidOperation: "expected 'struct'") }; @@ -374,12 +369,12 @@ fn expand_struct_fields( fields .iter() .flat_map(|field| { - let name = field.name().as_str(); + let name = field.name(); - if exclude.contains(name) { + if exclude.contains(name.as_str()) { None } else { - Some(Arc::from(field.name().as_str())) + Some(name.clone()) } }) .collect::>() @@ -394,11 +389,11 @@ fn expand_struct_fields( fields .iter() .flat_map(|field| { - let name = field.name().as_str(); - if exclude.contains(name) || !re.is_match(name) { + let name = field.name(); + if exclude.contains(name.as_str()) || !re.is_match(name.as_str()) { None } else { - Some(Arc::from(field.name().as_str())) + Some(name.clone()) } }) .collect::>() @@ -409,11 +404,18 @@ fn expand_struct_fields( } }; - return expand_struct_fields(struct_expr, full_expr, result, schema, &names, exclude); + return expand_struct_fields( + struct_expr, + full_expr, + result, + schema, + names.as_slice(), + exclude, + ); } for name in names { - polars_ensure!(name.as_ref() != "*", InvalidOperation: "cannot combine wildcards and column names"); + polars_ensure!(name.as_str() != "*", InvalidOperation: "cannot combine wildcards and column names"); if !exclude.contains(name) { let mut new_expr = replace_struct_multiple_fields_with_field(full_expr.clone(), name)?; @@ -423,7 +425,7 @@ fn expand_struct_fields( }, Expr::RenameAlias { expr, function } => { let name = function.call(name)?; - new_expr = Expr::Alias(expr, ColumnName::from(name)); + new_expr = Expr::Alias(expr, name); }, _ => {}, } @@ -440,7 +442,7 @@ fn 
expand_indices( result: &mut Vec, schema: &Schema, indices: &[i64], - exclude: &PlHashSet>, + exclude: &PlHashSet, ) -> PolarsResult<()> { let n_fields = schema.len() as i64; for idx in indices { @@ -454,11 +456,7 @@ fn expand_indices( if let Some((name, _)) = schema.get_at_index(idx as usize) { if !exclude.contains(name.as_str()) { let new_expr = expr.clone(); - let new_expr = replace_dtype_or_index_with_column( - new_expr, - &ColumnName::from(name.as_str()), - false, - ); + let new_expr = replace_dtype_or_index_with_column(new_expr, name, false); let new_expr = rewrite_special_aliases(new_expr)?; result.push(new_expr); } @@ -474,7 +472,7 @@ fn prepare_excluded( schema: &Schema, keys: &[Expr], has_exclude: bool, -) -> PolarsResult>> { +) -> PolarsResult> { let mut exclude = PlHashSet::new(); // explicit exclude branch @@ -501,8 +499,8 @@ fn prepare_excluded( }, Excluded::Dtype(dt) => { for fld in schema.iter_fields() { - if dtypes_match(fld.data_type(), dt) { - exclude.insert(ColumnName::from(fld.name().as_ref())); + if dtypes_match(fld.dtype(), dt) { + exclude.insert(fld.name.clone()); } } }, @@ -520,7 +518,7 @@ fn prepare_excluded( Excluded::Dtype(dt) => { for (name, dtype) in schema.iter() { if matches!(dtype, dt) { - exclude.insert(ColumnName::from(name.as_str())); + exclude.insert(name.clone()); } } }, @@ -541,14 +539,18 @@ fn prepare_excluded( } // functions can have col(["a", "b"]) or col(String) as inputs -fn expand_function_inputs(expr: Expr, schema: &Schema) -> PolarsResult { +fn expand_function_inputs( + expr: Expr, + schema: &Schema, + opt_flags: &mut OptFlags, +) -> PolarsResult { expr.try_map_expr(|mut e| match &mut e { Expr::AnonymousFunction { input, options, .. } | Expr::Function { input, options, .. } if options .flags .contains(FunctionFlags::INPUT_WILDCARD_EXPANSION) => { - *input = rewrite_projections(core::mem::take(input), schema, &[]).unwrap(); + *input = rewrite_projections(core::mem::take(input), schema, &[], opt_flags).unwrap(); if input.is_empty() && !options.flags.contains(FunctionFlags::ALLOW_EMPTY_INPUTS) { // Needed to visualize the error *input = vec![Expr::Literal(LiteralValue::Null)]; @@ -639,12 +641,27 @@ fn find_flags(expr: &Expr) -> PolarsResult { }) } +#[cfg(feature = "dtype-struct")] +fn toggle_cse(opt_flags: &mut OptFlags) { + if opt_flags.contains(OptFlags::EAGER) { + #[cfg(debug_assertions)] + { + use polars_core::config::verbose; + if verbose() { + eprintln!("CSE turned on because of struct expansion") + } + } + *opt_flags |= OptFlags::COMM_SUBEXPR_ELIM; + } +} + /// In case of single col(*) -> do nothing, no selection is the same as select all /// In other cases replace the wildcard with an expression with all columns pub(crate) fn rewrite_projections( exprs: Vec, schema: &Schema, keys: &[Expr], + opt_flags: &mut OptFlags, ) -> PolarsResult> { let mut result = Vec::with_capacity(exprs.len() + schema.len()); @@ -653,7 +670,7 @@ pub(crate) fn rewrite_projections( let result_offset = result.len(); // Functions can have col(["a", "b"]) or col(String) as inputs. 
- expr = expand_function_inputs(expr, schema)?; + expr = expand_function_inputs(expr, schema, opt_flags)?; let mut flags = find_flags(&expr)?; if flags.has_selector { @@ -662,10 +679,11 @@ pub(crate) fn rewrite_projections( flags.multiple_columns = true; } - replace_and_add_to_results(expr, flags, &mut result, schema, keys)?; + replace_and_add_to_results(expr, flags, &mut result, schema, keys, opt_flags)?; #[cfg(feature = "dtype-struct")] if flags.has_struct_field_by_index { + toggle_cse(opt_flags); for e in &mut result[result_offset..] { *e = struct_index_to_field(std::mem::take(e), schema)?; } @@ -680,6 +698,7 @@ fn replace_and_add_to_results( result: &mut Vec, schema: &Schema, keys: &[Expr], + opt_flags: &mut OptFlags, ) -> PolarsResult<()> { if flags.has_nth { expr = replace_nth(expr, schema); @@ -732,6 +751,7 @@ fn replace_and_add_to_results( &mut intermediate, schema, keys, + opt_flags, )?; // Then expand the fields and add to the final result vec. @@ -739,12 +759,13 @@ fn replace_and_add_to_results( flags.multiple_columns = false; flags.has_wildcard = false; for e in intermediate { - replace_and_add_to_results(e, flags, result, schema, keys)?; + replace_and_add_to_results(e, flags, result, schema, keys, opt_flags)?; } } // has only field expansion // col('a').struct.field('*') else { + toggle_cse(opt_flags); expand_struct_fields(e, &expr, result, schema, names, &exclude)? } }, @@ -787,7 +808,14 @@ fn replace_selector_inner( match s { Selector::Root(expr) => { let local_flags = find_flags(&expr)?; - replace_and_add_to_results(*expr, local_flags, scratch, schema, keys)?; + replace_and_add_to_results( + *expr, + local_flags, + scratch, + schema, + keys, + &mut Default::default(), + )?; members.extend(scratch.drain(..)) }, Selector::Add(lhs, rhs) => { @@ -847,11 +875,11 @@ fn replace_selector(expr: Expr, schema: &Schema, keys: &[Expr]) -> PolarsResult< }) } -pub(super) fn expand_selectors( +pub(crate) fn expand_selectors( s: Vec, schema: &Schema, keys: &[Expr], -) -> PolarsResult> { +) -> PolarsResult> { let mut columns = vec![]; // Skip the column fast paths. 
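To make the wildcard-expansion step above concrete, here is a tiny self-contained sketch (toy `Expr` type, not the polars-plan one) of how a single wildcard fans out into one column expression per schema field while honoring an exclude set.

use std::collections::HashSet;

/// Toy expression type for the example only.
#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Wildcard,
    Column(String),
}

/// Replace a wildcard with one `Column` expression per schema name,
/// skipping anything in `exclude`; other expressions pass through unchanged.
fn rewrite_projection(expr: Expr, schema: &[&str], exclude: &HashSet<&str>) -> Vec<Expr> {
    match expr {
        Expr::Wildcard => schema
            .iter()
            .filter(|name| !exclude.contains(*name))
            .map(|name| Expr::Column(name.to_string()))
            .collect(),
        other => vec![other],
    }
}

fn main() {
    let schema = ["a", "b", "c"];
    let exclude: HashSet<&str> = ["b"].into_iter().collect();
    let out = rewrite_projection(Expr::Wildcard, &schema, &exclude);
    assert_eq!(out, vec![Expr::Column("a".into()), Expr::Column("c".into())]);
}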
@@ -889,25 +917,25 @@ pub(super) fn expand_selector( s: Selector, schema: &Schema, keys: &[Expr], -) -> PolarsResult> { +) -> PolarsResult> { let mut members = PlIndexSet::new(); replace_selector_inner(s, &mut members, &mut vec![], schema, keys)?; if members.len() <= 1 { - Ok(members + members .into_iter() .map(|e| { let Expr::Column(name) = e else { - unreachable!() + polars_bail!(InvalidOperation: "invalid selector expression: {}", e) }; - name + Ok(name) }) - .collect()) + .collect() } else { // Ensure that multiple columns returned from combined/nested selectors remain in schema order let selected = schema .iter_fields() - .map(|field| ColumnName::from(field.name().as_ref())) + .map(|field| field.name().clone()) .filter(|field_name| members.contains(&Expr::Column(field_name.clone()))) .collect(); diff --git a/crates/polars-plan/src/plans/conversion/expr_to_ir.rs b/crates/polars-plan/src/plans/conversion/expr_to_ir.rs index f61357d203b7..1e6457eed810 100644 --- a/crates/polars-plan/src/plans/conversion/expr_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/expr_to_ir.rs @@ -1,32 +1,39 @@ use super::*; +use crate::plans::conversion::functions::convert_functions; -pub fn to_expr_ir(expr: Expr, arena: &mut Arena) -> ExprIR { - let mut state = ConversionState::new(); - let node = to_aexpr_impl(expr, arena, &mut state); - ExprIR::new(node, state.output_name) +pub fn to_expr_ir(expr: Expr, arena: &mut Arena) -> PolarsResult { + let mut state = ConversionContext::new(); + let node = to_aexpr_impl(expr, arena, &mut state)?; + Ok(ExprIR::new(node, state.output_name)) } -pub(super) fn to_expr_irs(input: Vec, arena: &mut Arena) -> Vec { - input.convert_owned(|e| to_expr_ir(e, arena)) +pub(super) fn to_expr_irs(input: Vec, arena: &mut Arena) -> PolarsResult> { + input.into_iter().map(|e| to_expr_ir(e, arena)).collect() } -pub fn to_expr_ir_ignore_alias(expr: Expr, arena: &mut Arena) -> ExprIR { - let mut state = ConversionState::new(); +pub fn to_expr_ir_ignore_alias(expr: Expr, arena: &mut Arena) -> PolarsResult { + let mut state = ConversionContext::new(); state.ignore_alias = true; - let node = to_aexpr_impl_materialized_lit(expr, arena, &mut state); - ExprIR::new(node, state.output_name) + let node = to_aexpr_impl_materialized_lit(expr, arena, &mut state)?; + Ok(ExprIR::new(node, state.output_name)) } -pub(super) fn to_expr_irs_ignore_alias(input: Vec, arena: &mut Arena) -> Vec { - input.convert_owned(|e| to_expr_ir_ignore_alias(e, arena)) +pub(super) fn to_expr_irs_ignore_alias( + input: Vec, + arena: &mut Arena, +) -> PolarsResult> { + input + .into_iter() + .map(|e| to_expr_ir_ignore_alias(e, arena)) + .collect() } /// converts expression to AExpr and adds it to the arena, which uses an arena (Vec) for allocation -pub fn to_aexpr(expr: Expr, arena: &mut Arena) -> Node { +pub fn to_aexpr(expr: Expr, arena: &mut Arena) -> PolarsResult { to_aexpr_impl_materialized_lit( expr, arena, - &mut ConversionState { + &mut ConversionContext { prune_alias: false, ..Default::default() }, @@ -34,15 +41,15 @@ pub fn to_aexpr(expr: Expr, arena: &mut Arena) -> Node { } #[derive(Default)] -struct ConversionState { - output_name: OutputName, +pub(super) struct ConversionContext { + pub(super) output_name: OutputName, /// Remove alias from the expressions and set as [`OutputName`]. - prune_alias: bool, + pub(super) prune_alias: bool, /// If an `alias` is encountered prune and ignore it. 
- ignore_alias: bool, + pub(super) ignore_alias: bool, } -impl ConversionState { +impl ConversionContext { fn new() -> Self { Self { prune_alias: true, @@ -51,20 +58,28 @@ impl ConversionState { } } -fn to_aexprs(input: Vec, arena: &mut Arena, state: &mut ConversionState) -> Vec { +fn to_aexprs( + input: Vec, + arena: &mut Arena, + state: &mut ConversionContext, +) -> PolarsResult> { input .into_iter() .map(|e| to_aexpr_impl_materialized_lit(e, arena, state)) .collect() } -fn set_function_output_name(e: &[ExprIR], state: &mut ConversionState, function_fmt: F) -where - F: FnOnce() -> Cow<'static, str>, +pub(super) fn set_function_output_name( + e: &[ExprIR], + state: &mut ConversionContext, + function_fmt: F, +) where + F: FnOnce() -> PlSmallStr, { if state.output_name.is_none() { if e.is_empty() { - state.output_name = OutputName::LiteralLhs(ColumnName::from(function_fmt().as_ref())); + let s = function_fmt(); + state.output_name = OutputName::LiteralLhs(s); } else { state.output_name = e[0].output_name_inner().clone(); } @@ -74,8 +89,8 @@ where fn to_aexpr_impl_materialized_lit( expr: Expr, arena: &mut Arena, - state: &mut ConversionState, -) -> Node { + state: &mut ConversionContext, +) -> PolarsResult { // Already convert `Lit Float and Lit Int` expressions that are not used in a binary / function expression. // This means they can be materialized immediately let e = match expr { @@ -106,24 +121,28 @@ fn to_aexpr_impl_materialized_lit( /// Converts expression to AExpr and adds it to the arena, which uses an arena (Vec) for allocation. #[recursive] -fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionState) -> Node { +pub(super) fn to_aexpr_impl( + expr: Expr, + arena: &mut Arena, + state: &mut ConversionContext, +) -> PolarsResult { let owned = Arc::unwrap_or_clone; let v = match expr { - Expr::Explode(expr) => AExpr::Explode(to_aexpr_impl(owned(expr), arena, state)), + Expr::Explode(expr) => AExpr::Explode(to_aexpr_impl(owned(expr), arena, state)?), Expr::Alias(e, name) => { if state.prune_alias { if state.output_name.is_none() && !state.ignore_alias { state.output_name = OutputName::Alias(name); } - to_aexpr_impl(owned(e), arena, state); + let _ = to_aexpr_impl(owned(e), arena, state)?; arena.pop().unwrap() } else { - AExpr::Alias(to_aexpr_impl(owned(e), arena, state), name) + AExpr::Alias(to_aexpr_impl(owned(e), arena, state)?, name) } }, Expr::Literal(lv) => { if state.output_name.is_none() { - state.output_name = OutputName::LiteralLhs(lv.output_column_name()); + state.output_name = OutputName::LiteralLhs(lv.output_column_name().clone()); } AExpr::Literal(lv) }, @@ -134,8 +153,8 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta AExpr::Column(name) }, Expr::BinaryExpr { left, op, right } => { - let l = to_aexpr_impl(owned(left), arena, state); - let r = to_aexpr_impl(owned(right), arena, state); + let l = to_aexpr_impl(owned(left), arena, state)?; + let r = to_aexpr_impl(owned(right), arena, state)?; AExpr::BinaryExpr { left: l, op, @@ -144,11 +163,11 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta }, Expr::Cast { expr, - data_type, + dtype, options, } => AExpr::Cast { - expr: to_aexpr_impl(owned(expr), arena, state), - data_type, + expr: to_aexpr_impl(owned(expr), arena, state)?, + dtype, options, }, Expr::Gather { @@ -156,12 +175,12 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta idx, returns_scalar, } => AExpr::Gather { - expr: to_aexpr_impl(owned(expr), arena, state), - idx: 
to_aexpr_impl_materialized_lit(owned(idx), arena, state), + expr: to_aexpr_impl(owned(expr), arena, state)?, + idx: to_aexpr_impl_materialized_lit(owned(idx), arena, state)?, returns_scalar, }, Expr::Sort { expr, options } => AExpr::Sort { - expr: to_aexpr_impl(owned(expr), arena, state), + expr: to_aexpr_impl(owned(expr), arena, state)?, options, }, Expr::SortBy { @@ -169,16 +188,16 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta by, sort_options, } => AExpr::SortBy { - expr: to_aexpr_impl(owned(expr), arena, state), + expr: to_aexpr_impl(owned(expr), arena, state)?, by: by .into_iter() .map(|e| to_aexpr_impl(e, arena, state)) - .collect(), + .collect::>()?, sort_options, }, Expr::Filter { input, by } => AExpr::Filter { - input: to_aexpr_impl(owned(input), arena, state), - by: to_aexpr_impl(owned(by), arena, state), + input: to_aexpr_impl(owned(input), arena, state)?, + by: to_aexpr_impl(owned(by), arena, state)?, }, Expr::Agg(agg) => { let a_agg = match agg { @@ -186,36 +205,36 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta input, propagate_nans, } => IRAggExpr::Min { - input: to_aexpr_impl_materialized_lit(owned(input), arena, state), + input: to_aexpr_impl_materialized_lit(owned(input), arena, state)?, propagate_nans, }, AggExpr::Max { input, propagate_nans, } => IRAggExpr::Max { - input: to_aexpr_impl_materialized_lit(owned(input), arena, state), + input: to_aexpr_impl_materialized_lit(owned(input), arena, state)?, propagate_nans, }, AggExpr::Median(expr) => { - IRAggExpr::Median(to_aexpr_impl_materialized_lit(owned(expr), arena, state)) + IRAggExpr::Median(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?) }, AggExpr::NUnique(expr) => { - IRAggExpr::NUnique(to_aexpr_impl_materialized_lit(owned(expr), arena, state)) + IRAggExpr::NUnique(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?) }, AggExpr::First(expr) => { - IRAggExpr::First(to_aexpr_impl_materialized_lit(owned(expr), arena, state)) + IRAggExpr::First(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?) }, AggExpr::Last(expr) => { - IRAggExpr::Last(to_aexpr_impl_materialized_lit(owned(expr), arena, state)) + IRAggExpr::Last(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?) }, AggExpr::Mean(expr) => { - IRAggExpr::Mean(to_aexpr_impl_materialized_lit(owned(expr), arena, state)) + IRAggExpr::Mean(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?) }, AggExpr::Implode(expr) => { - IRAggExpr::Implode(to_aexpr_impl_materialized_lit(owned(expr), arena, state)) + IRAggExpr::Implode(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?) }, AggExpr::Count(expr, include_nulls) => IRAggExpr::Count( - to_aexpr_impl_materialized_lit(owned(expr), arena, state), + to_aexpr_impl_materialized_lit(owned(expr), arena, state)?, include_nulls, ), AggExpr::Quantile { @@ -223,23 +242,23 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta quantile, interpol, } => IRAggExpr::Quantile { - expr: to_aexpr_impl_materialized_lit(owned(expr), arena, state), - quantile: to_aexpr_impl_materialized_lit(owned(quantile), arena, state), + expr: to_aexpr_impl_materialized_lit(owned(expr), arena, state)?, + quantile: to_aexpr_impl_materialized_lit(owned(quantile), arena, state)?, interpol, }, AggExpr::Sum(expr) => { - IRAggExpr::Sum(to_aexpr_impl_materialized_lit(owned(expr), arena, state)) + IRAggExpr::Sum(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?) 
}, AggExpr::Std(expr, ddof) => IRAggExpr::Std( - to_aexpr_impl_materialized_lit(owned(expr), arena, state), + to_aexpr_impl_materialized_lit(owned(expr), arena, state)?, ddof, ), AggExpr::Var(expr, ddof) => IRAggExpr::Var( - to_aexpr_impl_materialized_lit(owned(expr), arena, state), + to_aexpr_impl_materialized_lit(owned(expr), arena, state)?, ddof, ), AggExpr::AggGroups(expr) => { - IRAggExpr::AggGroups(to_aexpr_impl_materialized_lit(owned(expr), arena, state)) + IRAggExpr::AggGroups(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?) }, }; AExpr::Agg(a_agg) @@ -250,9 +269,9 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta falsy, } => { // Truthy must be resolved first to get the lhs name first set. - let t = to_aexpr_impl(owned(truthy), arena, state); - let p = to_aexpr_impl_materialized_lit(owned(predicate), arena, state); - let f = to_aexpr_impl(owned(falsy), arena, state); + let t = to_aexpr_impl(owned(truthy), arena, state)?; + let p = to_aexpr_impl_materialized_lit(owned(predicate), arena, state)?; + let f = to_aexpr_impl(owned(falsy), arena, state)?; AExpr::Ternary { predicate: p, truthy: t, @@ -265,8 +284,8 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta output_type, options, } => { - let e = to_expr_irs(input, arena); - set_function_output_name(&e, state, || Cow::Borrowed(options.fmt_str)); + let e = to_expr_irs(input, arena)?; + set_function_output_name(&e, state, || PlSmallStr::from_static(options.fmt_str)); AExpr::AnonymousFunction { input: e, function, @@ -278,74 +297,34 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta input, function, options, - } => { - match function { - // This can be created by col(*).is_null() on empty dataframes. - FunctionExpr::Boolean( - BooleanFunction::AllHorizontal | BooleanFunction::AnyHorizontal, - ) if input.is_empty() => { - return to_aexpr_impl(lit(true), arena, state); - }, - // Convert to binary expression as the optimizer understands those. - // Don't exceed 128 expressions as we might stackoverflow. - FunctionExpr::Boolean(BooleanFunction::AllHorizontal) => { - if input.len() < 128 { - let expr = input - .into_iter() - .reduce(|l, r| l.logical_and(r)) - .unwrap() - .cast(DataType::Boolean); - return to_aexpr_impl(expr, arena, state); - } - }, - FunctionExpr::Boolean(BooleanFunction::AnyHorizontal) => { - if input.len() < 128 { - let expr = input - .into_iter() - .reduce(|l, r| l.logical_or(r)) - .unwrap() - .cast(DataType::Boolean); - return to_aexpr_impl(expr, arena, state); - } - }, - _ => {}, - } - - let e = to_expr_irs(input, arena); - - if state.output_name.is_none() { - // Handles special case functions like `struct.field`. 
- if let Some(name) = function.output_name() { - state.output_name = name - } else { - set_function_output_name(&e, state, || Cow::Owned(format!("{}", &function))); - } - } - AExpr::Function { - input: e, - function, - options, - } - }, + } => return convert_functions(input, function, options, arena, state), Expr::Window { function, partition_by, order_by, options, - } => AExpr::Window { - function: to_aexpr_impl(owned(function), arena, state), - partition_by: to_aexprs(partition_by, arena, state), - order_by: order_by.map(|(e, options)| (to_aexpr_impl(owned(e), arena, state), options)), - options, + } => { + let order_by = if let Some((e, options)) = order_by { + Some((to_aexpr_impl(owned(e.clone()), arena, state)?, options)) + } else { + None + }; + + AExpr::Window { + function: to_aexpr_impl(owned(function), arena, state)?, + partition_by: to_aexprs(partition_by, arena, state)?, + order_by, + options, + } }, Expr::Slice { input, offset, length, } => AExpr::Slice { - input: to_aexpr_impl(owned(input), arena, state), - offset: to_aexpr_impl_materialized_lit(owned(offset), arena, state), - length: to_aexpr_impl_materialized_lit(owned(length), arena, state), + input: to_aexpr_impl(owned(input), arena, state)?, + offset: to_aexpr_impl_materialized_lit(owned(offset), arena, state)?, + length: to_aexpr_impl_materialized_lit(owned(length), arena, state)?, }, Expr::Len => { if state.output_name.is_none() { @@ -353,24 +332,22 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta } AExpr::Len }, - Expr::Nth(i) => AExpr::Nth(i), - Expr::IndexColumn(idx) => { - if idx.len() == 1 { - AExpr::Nth(idx[0]) - } else { - panic!("no multi-value `index-columns` expected at this point") - } - }, - Expr::Wildcard => AExpr::Wildcard, #[cfg(feature = "dtype-struct")] - Expr::Field(_) => unreachable!(), // replaced during expansion - Expr::SubPlan { .. } => panic!("no SQL subquery expected at this point"), - Expr::KeepName(_) => panic!("no `name.keep` expected at this point"), - Expr::Exclude(_, _) => panic!("no `exclude` expected at this point"), - Expr::RenameAlias { .. } => panic!("no `rename_alias` expected at this point"), - Expr::Columns { .. } => panic!("no `columns` expected at this point"), - Expr::DtypeColumn { .. } => panic!("no `dtype-columns` expected at this point"), - Expr::Selector(_) => panic!("no `selector` expected at this point"), + e @ Expr::Field(_) => { + polars_bail!(InvalidOperation: "'Expr: {}' not allowed in this context/location", e) + }, + e @ Expr::IndexColumn(_) + | e @ Expr::Wildcard + | e @ Expr::Nth(_) + | e @ Expr::SubPlan { .. } + | e @ Expr::KeepName(_) + | e @ Expr::Exclude(_, _) + | e @ Expr::RenameAlias { .. } + | e @ Expr::Columns { .. } + | e @ Expr::DtypeColumn { .. 
} + | e @ Expr::Selector(_) => { + polars_bail!(InvalidOperation: "'Expr: {}' not allowed in this context/location", e) + }, }; - arena.add(v) + Ok(arena.add(v)) } diff --git a/crates/polars-plan/src/plans/conversion/functions.rs b/crates/polars-plan/src/plans/conversion/functions.rs new file mode 100644 index 000000000000..f77cb0e79a95 --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/functions.rs @@ -0,0 +1,67 @@ +use arrow::legacy::error::PolarsResult; +use polars_utils::arena::{Arena, Node}; +use polars_utils::format_pl_smallstr; + +use super::*; +use crate::dsl::{Expr, FunctionExpr}; +use crate::plans::AExpr; +use crate::prelude::FunctionOptions; + +pub(super) fn convert_functions( + input: Vec, + function: FunctionExpr, + options: FunctionOptions, + arena: &mut Arena, + state: &mut ConversionContext, +) -> PolarsResult { + match function { + // This can be created by col(*).is_null() on empty dataframes. + FunctionExpr::Boolean(BooleanFunction::AllHorizontal) if input.is_empty() => { + return to_aexpr_impl(lit(true), arena, state); + }, + FunctionExpr::Boolean(BooleanFunction::AnyHorizontal) if input.is_empty() => { + return to_aexpr_impl(lit(false), arena, state); + }, + // Convert to binary expression as the optimizer understands those. + // Don't exceed 128 expressions as we might stackoverflow. + FunctionExpr::Boolean(BooleanFunction::AllHorizontal) => { + if input.len() < 128 { + let expr = input + .into_iter() + .reduce(|l, r| l.logical_and(r)) + .unwrap() + .cast(DataType::Boolean); + return to_aexpr_impl(expr, arena, state); + } + }, + FunctionExpr::Boolean(BooleanFunction::AnyHorizontal) => { + if input.len() < 128 { + let expr = input + .into_iter() + .reduce(|l, r| l.logical_or(r)) + .unwrap() + .cast(DataType::Boolean); + return to_aexpr_impl(expr, arena, state); + } + }, + _ => {}, + } + + let e = to_expr_irs(input, arena)?; + + if state.output_name.is_none() { + // Handles special case functions like `struct.field`. 
+ if let Some(name) = function.output_name() { + state.output_name = name + } else { + set_function_output_name(&e, state, || format_pl_smallstr!("{}", &function)); + } + } + + let ae_function = AExpr::Function { + input: e, + function, + options, + }; + Ok(arena.add(ae_function)) +} diff --git a/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs b/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs index a7c2fac17edf..c90590914e47 100644 --- a/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs +++ b/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs @@ -24,13 +24,13 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { }, AExpr::Cast { expr, - data_type, + dtype, options: strict, } => { let exp = node_to_expr(expr, expr_arena); Expr::Cast { expr: Arc::new(exp), - data_type, + dtype, options: strict, } }, @@ -223,8 +223,6 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { length: Arc::new(node_to_expr(length, expr_arena)), }, AExpr::Len => Expr::Len, - AExpr::Nth(i) => Expr::Nth(i), - AExpr::Wildcard => Expr::Wildcard, } } diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs new file mode 100644 index 000000000000..6c3e28bb6c7a --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -0,0 +1,339 @@ +use arrow::legacy::error::PolarsResult; + +use super::*; +use crate::dsl::Expr; +use crate::plans::AExpr; + +fn check_join_keys(keys: &[Expr]) -> PolarsResult<()> { + for e in keys { + if has_expr(e, |e| matches!(e, Expr::Alias(_, _))) { + polars_bail!( + InvalidOperation: + "'alias' is not allowed in a join key, use 'with_columns' first", + ) + } + } + Ok(()) +} +pub fn resolve_join( + input_left: Arc, + input_right: Arc, + left_on: Vec, + right_on: Vec, + predicates: Vec, + mut options: Arc, + ctxt: &mut DslConversionContext, +) -> PolarsResult { + if !predicates.is_empty() { + debug_assert!(left_on.is_empty() && right_on.is_empty()); + return resolve_join_where(input_left, input_right, predicates, options, ctxt); + } + + let owned = Arc::unwrap_or_clone; + if matches!(options.args.how, JoinType::Cross) { + polars_ensure!(left_on.len() + right_on.len() == 0, InvalidOperation: "a 'cross' join doesn't expect any join keys"); + } else { + check_join_keys(&left_on)?; + check_join_keys(&right_on)?; + + let mut turn_off_coalesce = false; + for e in left_on.iter().chain(right_on.iter()) { + // Any expression that is not a simple column expression will turn of coalescing. 
+ turn_off_coalesce |= has_expr(e, |e| !matches!(e, Expr::Column(_))); + } + if turn_off_coalesce { + let options = Arc::make_mut(&mut options); + if matches!(options.args.coalesce, JoinCoalesce::CoalesceColumns) { + polars_warn!("coalescing join requested but not all join keys are column references, turning off key coalescing"); + } + options.args.coalesce = JoinCoalesce::KeepColumns; + } + + options.args.validation.is_valid_join(&options.args.how)?; + + polars_ensure!( + left_on.len() == right_on.len(), + InvalidOperation: + format!( + "the number of columns given as join key (left: {}, right:{}) should be equal", + left_on.len(), + right_on.len() + ) + ); + } + + let input_left = + to_alp_impl(owned(input_left), ctxt).map_err(|e| e.context(failed_input!(join left)))?; + let input_right = + to_alp_impl(owned(input_right), ctxt).map_err(|e| e.context(failed_input!(join, right)))?; + + let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena); + let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena); + + let schema = det_join_schema(&schema_left, &schema_right, &left_on, &right_on, &options) + .map_err(|e| e.context(failed_here!(join schema resolving)))?; + + let left_on = to_expr_irs_ignore_alias(left_on, ctxt.expr_arena)?; + let right_on = to_expr_irs_ignore_alias(right_on, ctxt.expr_arena)?; + let mut joined_on = PlHashSet::new(); + for (l, r) in left_on.iter().zip(right_on.iter()) { + polars_ensure!( + joined_on.insert((l.output_name(), r.output_name())), + InvalidOperation: "joining with repeated key names; already joined on {} and {}", + l.output_name(), + r.output_name() + ) + } + drop(joined_on); + + ctxt.conversion_optimizer + .fill_scratch(&left_on, ctxt.expr_arena); + ctxt.conversion_optimizer + .fill_scratch(&right_on, ctxt.expr_arena); + + // Every expression must be elementwise so that we are + // guaranteed the keys for a join are all the same length. + let all_elementwise = + |aexprs: &[ExprIR]| all_streamable(aexprs, &*ctxt.expr_arena, Context::Default); + polars_ensure!( + all_elementwise(&left_on) && all_elementwise(&right_on), + InvalidOperation: "All join key expressions must be elementwise." + ); + let lp = IR::Join { + input_left, + input_right, + schema, + left_on, + right_on, + options, + }; + run_conversion(lp, ctxt, "join") +} + +impl From for Operator { + fn from(value: InequalityOperator) -> Self { + match value { + InequalityOperator::LtEq => Operator::LtEq, + InequalityOperator::Lt => Operator::Lt, + InequalityOperator::GtEq => Operator::GtEq, + InequalityOperator::Gt => Operator::Gt, + } + } +} + +fn resolve_join_where( + input_left: Arc, + input_right: Arc, + predicates: Vec, + mut options: Arc, + ctxt: &mut DslConversionContext, +) -> PolarsResult { + check_join_keys(&predicates)?; + + for e in &predicates { + let no_binary_comparisons = e + .into_iter() + .filter(|e| match e { + Expr::BinaryExpr { op, .. 
} => op.is_comparison(), + _ => false, + }) + .count(); + polars_ensure!(no_binary_comparisons == 1, InvalidOperation: "only 1 binary comparison allowed as join condition") + } + + let owned = |e: Arc| (*e).clone(); + + // Partition to: + // - IEjoin supported inequality predicates + // - equality predicates + // - remaining predicates + let mut ie_left_on = vec![]; + let mut ie_right_on = vec![]; + let mut ie_op = vec![]; + + let mut eq_left_on = vec![]; + let mut eq_right_on = vec![]; + + let mut remaining_preds = vec![]; + + fn to_inequality_operator(op: &Operator) -> Option { + match op { + Operator::Lt => Some(InequalityOperator::Lt), + Operator::LtEq => Some(InequalityOperator::LtEq), + Operator::Gt => Some(InequalityOperator::Gt), + Operator::GtEq => Some(InequalityOperator::GtEq), + _ => None, + } + } + + for pred in predicates.into_iter() { + let Expr::BinaryExpr { left, op, right } = pred.clone() else { + polars_bail!(InvalidOperation: "can only join on binary expressions") + }; + polars_ensure!(op.is_comparison(), InvalidOperation: "expected comparison in join predicate"); + + if let Some(ie_op_) = to_inequality_operator(&op) { + // We already have an IEjoin or an Inner join, push to remaining + if ie_op.len() >= 2 || !eq_right_on.is_empty() { + remaining_preds.push(Expr::BinaryExpr { left, op, right }) + } else { + ie_left_on.push(owned(left)); + ie_right_on.push(owned(right)); + ie_op.push(ie_op_) + } + } else if matches!(op, Operator::Eq) { + eq_left_on.push(owned(left)); + eq_right_on.push(owned(right)); + } else { + remaining_preds.push(pred); + } + } + + // Now choose a primary join and do the remaining predicates as filters + fn to_binary(l: Expr, op: Operator, r: Expr) -> Expr { + Expr::BinaryExpr { + left: Arc::from(l), + op, + right: Arc::from(r), + } + } + // Add the ie predicates to the remaining predicates buffer so that they will be executed in the + // filter node. + fn ie_predicates_to_remaining( + remaining_preds: &mut Vec, + ie_left_on: Vec, + ie_right_on: Vec, + ie_op: Vec, + ) { + for ((l, op), r) in ie_left_on + .into_iter() + .zip(ie_op.into_iter()) + .zip(ie_right_on.into_iter()) + { + remaining_preds.push(to_binary(l, op.into(), r)) + } + } + + let join_node = if !eq_left_on.is_empty() { + // We found one or more equality predicates. Go into a default equi join + // as those are cheapest on avg. + let join_node = resolve_join( + input_left, + input_right, + eq_left_on, + eq_right_on, + vec![], + options.clone(), + ctxt, + )?; + + ie_predicates_to_remaining(&mut remaining_preds, ie_left_on, ie_right_on, ie_op); + join_node + } + // TODO! once we support single IEjoin predicates, we must add a branch for the singe ie_pred case. + else if ie_right_on.len() >= 2 { + // Do an IEjoin. + let opts = Arc::make_mut(&mut options); + opts.args.how = JoinType::IEJoin(IEJoinOptions { + operator1: ie_op[0], + operator2: ie_op[1], + }); + + let join_node = resolve_join( + input_left, + input_right, + ie_left_on[..2].to_vec(), + ie_right_on[..2].to_vec(), + vec![], + options.clone(), + ctxt, + )?; + + // The surplus ie-predicates will be added to the remaining predicates so that + // they will be applied in a filter node. + while ie_right_on.len() > 2 { + // Invariant: they all have equal length, so we can pop and unwrap all while len > 2. 
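A condensed, stand-alone sketch of the predicate partitioning performed by `resolve_join_where` (simplified types and rules, not the actual implementation): equality predicates drive an equi-join, up to two inequalities can drive an IEJoin, and everything else ends up as a post-join filter.

/// Toy predicate form: left column, comparison, right column.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Cmp { Eq, Lt, LtEq, Gt, GtEq }

#[derive(Clone, Debug)]
struct Pred { left: &'static str, op: Cmp, right: &'static str }

/// Split predicates roughly the way the planner does: equalities into one bucket,
/// at most two inequalities into an IEJoin bucket, the rest into a filter bucket.
fn partition(preds: Vec<Pred>) -> (Vec<Pred>, Vec<Pred>, Vec<Pred>) {
    let (mut eq, mut ie, mut rest) = (vec![], vec![], vec![]);
    for p in preds {
        match p.op {
            Cmp::Eq => eq.push(p),
            // Once an equality exists (or two inequalities are taken),
            // further inequalities are applied later as filters.
            _ if ie.len() >= 2 || !eq.is_empty() => rest.push(p),
            _ => ie.push(p),
        }
    }
    (eq, ie, rest)
}

fn main() {
    // With an equality present, an equi-join is preferred and the
    // inequality becomes a post-join filter.
    let (eq, ie, rest) = partition(vec![
        Pred { left: "id", op: Cmp::Eq, right: "id" },
        Pred { left: "start", op: Cmp::Lt, right: "end" },
    ]);
    assert_eq!((eq.len(), ie.len(), rest.len()), (1, 0, 1));

    // With only inequalities, the first two can feed an IEJoin and the rest become filters.
    let (eq, ie, rest) = partition(vec![
        Pred { left: "start", op: Cmp::Lt, right: "end" },
        Pred { left: "score", op: Cmp::GtEq, right: "threshold" },
        Pred { left: "a", op: Cmp::Gt, right: "b" },
    ]);
    assert_eq!((eq.len(), ie.len(), rest.len()), (0, 2, 1));
}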
+ // The first 2 predicates are used in the + let l = ie_right_on.pop().unwrap(); + let r = ie_left_on.pop().unwrap(); + let op = ie_op.pop().unwrap(); + + remaining_preds.push(to_binary(l, op.into(), r)) + } + join_node + } else { + // No predicates found that are supported in a fast algorithm. + // Do a cross join and follow up with filters. + let opts = Arc::make_mut(&mut options); + opts.args.how = JoinType::Cross; + + let join_node = resolve_join( + input_left, + input_right, + vec![], + vec![], + vec![], + options.clone(), + ctxt, + )?; + // TODO: This can be removed once we support the single IEjoin. + ie_predicates_to_remaining(&mut remaining_preds, ie_left_on, ie_right_on, ie_op); + join_node + }; + + let IR::Join { + input_left, + input_right, + .. + } = ctxt.lp_arena.get(join_node) + else { + unreachable!() + }; + let schema_right = ctxt + .lp_arena + .get(*input_right) + .schema(ctxt.lp_arena) + .into_owned(); + + let schema_left = ctxt + .lp_arena + .get(*input_left) + .schema(ctxt.lp_arena) + .into_owned(); + + let suffix = options.args.suffix(); + + let mut last_node = join_node; + + // Ensure that the predicates use the proper suffix + for e in remaining_preds { + let predicate = to_expr_ir_ignore_alias(e, ctxt.expr_arena)?; + let AExpr::BinaryExpr { mut right, .. } = *ctxt.expr_arena.get(predicate.node()) else { + unreachable!() + }; + + let original_right = right; + + for name in aexpr_to_leaf_names(right, ctxt.expr_arena) { + polars_ensure!(schema_right.contains(name.as_str()), ColumnNotFound: "could not find column {name} in the right table during join operation"); + if schema_left.contains(name.as_str()) { + let new_name = _join_suffix_name(name.as_str(), suffix.as_str()); + + right = rename_matching_aexpr_leaf_names( + right, + ctxt.expr_arena, + name.as_str(), + new_name, + ); + } + } + ctxt.expr_arena.swap(right, original_right); + + let ir = IR::Filter { + input: last_node, + predicate, + }; + last_node = ctxt.lp_arena.add(ir); + } + Ok(last_node) +} diff --git a/crates/polars-plan/src/plans/conversion/mod.rs b/crates/polars-plan/src/plans/conversion/mod.rs index 275109cbc613..b9ed8711a438 100644 --- a/crates/polars-plan/src/plans/conversion/mod.rs +++ b/crates/polars-plan/src/plans/conversion/mod.rs @@ -3,11 +3,15 @@ mod dsl_to_ir; mod expr_expansion; mod expr_to_ir; mod ir_to_dsl; -#[cfg(any(feature = "ipc", feature = "parquet", feature = "csv"))] +#[cfg(any( + feature = "ipc", + feature = "parquet", + feature = "csv", + feature = "json" +))] mod scans; mod stack_opt; -use std::borrow::Cow; use std::sync::{Arc, Mutex, RwLock}; pub use dsl_to_ir::*; @@ -16,9 +20,11 @@ pub use ir_to_dsl::*; use polars_core::prelude::*; use polars_utils::vec::ConvertVec; use recursive::recursive; +mod functions; +mod join; pub(crate) mod type_coercion; -pub(crate) use expr_expansion::{is_regex_projection, prepare_projection, rewrite_projections}; +pub(crate) use expr_expansion::{expand_selectors, is_regex_projection, prepare_projection}; use crate::constants::get_len_name; use crate::prelude::*; @@ -44,7 +50,7 @@ impl IR { }; match lp { IR::Scan { - paths, + sources, file_info, hive_parts, predicate, @@ -52,7 +58,10 @@ impl IR { output_schema: _, file_options: options, } => DslPlan::Scan { - paths: Arc::new(Mutex::new((paths, true))), + sources: Arc::new(Mutex::new(DslScanSources { + sources, + is_expanded: true, + })), file_info: Arc::new(RwLock::new(Some(file_info))), hive_parts, predicate: predicate.map(|e| e.to_expr(expr_arena)), @@ -136,7 +145,7 @@ impl IR { let input = 
convert_to_lp(input, lp_arena); let expr = columns .iter_names() - .map(|name| Expr::Column(ColumnName::from(name.as_str()))) + .map(|name| Expr::Column(name.clone())) .collect::>(); DslPlan::Select { expr, @@ -210,6 +219,7 @@ impl IR { DslPlan::Join { input_left: Arc::new(i_l), input_right: Arc::new(i_r), + predicates: Default::default(), left_on, right_on, options, @@ -232,6 +242,15 @@ impl IR { }, IR::Distinct { input, options } => { let i = convert_to_lp(input, lp_arena); + let options = DistinctOptionsDSL { + subset: options.subset.map(|s| { + s.iter() + .map(|name| Expr::Column(name.clone()).into()) + .collect() + }), + maintain_order: options.maintain_order, + keep_strategy: options.keep_strategy, + }; DslPlan::Distinct { input: Arc::new(i), options, diff --git a/crates/polars-plan/src/plans/conversion/scans.rs b/crates/polars-plan/src/plans/conversion/scans.rs index 959327148f6c..25dd61aa1eb9 100644 --- a/crates/polars-plan/src/plans/conversion/scans.rs +++ b/crates/polars-plan/src/plans/conversion/scans.rs @@ -1,5 +1,3 @@ -use std::path::PathBuf; - use either::Either; use polars_io::path_utils::is_cloud_url; #[cfg(feature = "cloud")] @@ -9,17 +7,10 @@ use polars_io::RowIndex; use super::*; -fn get_first_path(paths: &[PathBuf]) -> PolarsResult<&PathBuf> { - // Use first path to get schema. - paths - .first() - .ok_or_else(|| polars_err!(ComputeError: "expected at least 1 path")) -} - #[cfg(any(feature = "parquet", feature = "ipc"))] fn prepare_output_schema(mut schema: Schema, row_index: Option<&RowIndex>) -> SchemaRef { if let Some(rc) = row_index { - let _ = schema.insert_at_index(0, rc.name.as_ref().into(), IDX_DTYPE); + let _ = schema.insert_at_index(0, rc.name.clone(), IDX_DTYPE); } Arc::new(schema) } @@ -28,7 +19,7 @@ fn prepare_output_schema(mut schema: Schema, row_index: Option<&RowIndex>) -> Sc fn prepare_schemas(mut schema: Schema, row_index: Option<&RowIndex>) -> (SchemaRef, SchemaRef) { if let Some(rc) = row_index { let reader_schema = schema.clone(); - let _ = schema.insert_at_index(0, rc.name.as_ref().into(), IDX_DTYPE); + let _ = schema.insert_at_index(0, rc.name.clone(), IDX_DTYPE); (Arc::new(reader_schema), Arc::new(schema)) } else { let schema = Arc::new(schema); @@ -38,44 +29,47 @@ fn prepare_schemas(mut schema: Schema, row_index: Option<&RowIndex>) -> (SchemaR #[cfg(feature = "parquet")] pub(super) fn parquet_file_info( - paths: &[PathBuf], + sources: &ScanSources, file_options: &FileScanOptions, - cloud_options: Option<&polars_io::cloud::CloudOptions>, + #[allow(unused)] cloud_options: Option<&polars_io::cloud::CloudOptions>, ) -> PolarsResult<(FileInfo, Option)> { - let path = get_first_path(paths)?; - - let (schema, reader_schema, num_rows, metadata) = if is_cloud_url(path) { - #[cfg(not(feature = "cloud"))] - panic!("One or more of the cloud storage features ('aws', 'gcp', ...) must be enabled."); - - #[cfg(feature = "cloud")] - { - let uri = path.to_string_lossy(); - get_runtime().block_on(async { - let mut reader = ParquetAsyncReader::from_uri(&uri, cloud_options, None).await?; - let reader_schema = reader.schema().await?; - let num_rows = reader.num_rows().await?; - let metadata = reader.get_metadata().await?.clone(); - - let schema = - prepare_output_schema((&reader_schema).into(), file_options.row_index.as_ref()); - PolarsResult::Ok((schema, reader_schema, Some(num_rows), Some(metadata))) - })? 
+ use polars_core::error::feature_gated; + + let (reader_schema, num_rows, metadata) = { + if sources.is_cloud_url() { + let first_path = &sources.as_paths().unwrap()[0]; + feature_gated!("cloud", { + let uri = first_path.to_string_lossy(); + get_runtime().block_on(async { + let mut reader = + ParquetAsyncReader::from_uri(&uri, cloud_options, None).await?; + + PolarsResult::Ok(( + reader.schema().await?, + Some(reader.num_rows().await?), + Some(reader.get_metadata().await?.clone()), + )) + })? + }) + } else { + let first_source = sources + .first() + .ok_or_else(|| polars_err!(ComputeError: "expected at least 1 source"))?; + let memslice = first_source.to_memslice()?; + let mut reader = ParquetReader::new(std::io::Cursor::new(memslice)); + ( + reader.schema()?, + Some(reader.num_rows()?), + Some(reader.get_metadata()?.clone()), + ) } - } else { - let file = polars_utils::open_file(path)?; - let mut reader = ParquetReader::new(file); - let reader_schema = reader.schema()?; - let schema = - prepare_output_schema((&reader_schema).into(), file_options.row_index.as_ref()); - ( - schema, - reader_schema, - Some(reader.num_rows()?), - Some(reader.get_metadata()?.clone()), - ) }; + let schema = prepare_output_schema( + Schema::from_arrow_schema(reader_schema.as_ref()), + file_options.row_index.as_ref(), + ); + let file_info = FileInfo::new( schema, Some(Either::Left(reader_schema)), @@ -88,34 +82,45 @@ pub(super) fn parquet_file_info( // TODO! return metadata arced #[cfg(feature = "ipc")] pub(super) fn ipc_file_info( - paths: &[PathBuf], + sources: &ScanSources, file_options: &FileScanOptions, cloud_options: Option<&polars_io::cloud::CloudOptions>, ) -> PolarsResult<(FileInfo, arrow::io::ipc::read::FileMetadata)> { - let path = get_first_path(paths)?; - - let metadata = if is_cloud_url(path) { - #[cfg(not(feature = "cloud"))] - panic!("One or more of the cloud storage features ('aws', 'gcp', ...) must be enabled."); - - #[cfg(feature = "cloud")] - { - let uri = path.to_string_lossy(); - get_runtime().block_on(async { - polars_io::ipc::IpcReaderAsync::from_uri(&uri, cloud_options) - .await? - .metadata() - .await - })? - } - } else { - arrow::io::ipc::read::read_file_metadata(&mut std::io::BufReader::new( - polars_utils::open_file(path)?, - ))? + use polars_core::error::feature_gated; + + let Some(first) = sources.first() else { + polars_bail!(ComputeError: "expected at least 1 source"); }; + + let metadata = match first { + ScanSourceRef::Path(path) => { + if is_cloud_url(path) { + feature_gated!("cloud", { + let uri = path.to_string_lossy(); + get_runtime().block_on(async { + polars_io::ipc::IpcReaderAsync::from_uri(&uri, cloud_options) + .await? + .metadata() + .await + })? + }) + } else { + arrow::io::ipc::read::read_file_metadata(&mut std::io::BufReader::new( + polars_utils::open_file(path)?, + ))? + } + }, + ScanSourceRef::File(file) => { + arrow::io::ipc::read::read_file_metadata(&mut std::io::BufReader::new(file))? + }, + ScanSourceRef::Buffer(buff) => { + arrow::io::ipc::read::read_file_metadata(&mut std::io::Cursor::new(buff))? 
+ }, + }; + let file_info = FileInfo::new( prepare_output_schema( - metadata.schema.as_ref().into(), + Schema::from_arrow_schema(metadata.schema.as_ref()), file_options.row_index.as_ref(), ), Some(Either::Left(Arc::clone(&metadata.schema))), @@ -127,115 +132,94 @@ pub(super) fn ipc_file_info( #[cfg(feature = "csv")] pub(super) fn csv_file_info( - paths: &[PathBuf], + sources: &ScanSources, file_options: &FileScanOptions, csv_options: &mut CsvReadOptions, cloud_options: Option<&polars_io::cloud::CloudOptions>, ) -> PolarsResult { use std::io::{Read, Seek}; + use polars_core::error::feature_gated; use polars_core::{config, POOL}; use polars_io::csv::read::schema_inference::SchemaInferenceResult; use polars_io::utils::get_reader_bytes; use rayon::iter::{IntoParallelIterator, ParallelIterator}; + polars_ensure!(!sources.is_empty(), ComputeError: "expected at least 1 source"); + // TODO: // * See if we can do better than scanning all files if there is a row limit // * See if we can do this without downloading the entire file // prints the error message if paths is empty. - let first_path = get_first_path(paths)?; - let run_async = is_cloud_url(first_path) || config::force_async(); + let run_async = sources.is_cloud_url() || (sources.is_paths() && config::force_async()); let cache_entries = { - #[cfg(feature = "cloud")] - { - if run_async { + if run_async { + feature_gated!("cloud", { Some(polars_io::file_cache::init_entries_from_uri_list( - paths + sources + .as_paths() + .unwrap() .iter() .map(|path| Arc::from(path.to_str().unwrap())) .collect::>() .as_slice(), cloud_options, )?) - } else { - None - } - } - #[cfg(not(feature = "cloud"))] - { - if run_async { - panic!("required feature `cloud` is not enabled") - } + }) + } else { + None } }; let infer_schema_func = |i| { - let file = if run_async { - #[cfg(feature = "cloud")] - { - let entry: &Arc = - &cache_entries.as_ref().unwrap()[i]; - entry.try_open_check_latest()? - } - #[cfg(not(feature = "cloud"))] - { - panic!("required feature `cloud` is not enabled") - } - } else { - let p: &PathBuf = &paths[i]; - polars_utils::open_file(p.as_ref())? - }; - - let mmap = unsafe { memmap::Mmap::map(&file).unwrap() }; + let source = sources.at(i); + let memslice = source.to_memslice_possibly_async(run_async, cache_entries.as_ref(), i)?; let owned = &mut vec![]; - - let mut curs = std::io::Cursor::new(maybe_decompress_bytes(mmap.as_ref(), owned)?); - - if curs.read(&mut [0; 4])? < 2 && csv_options.raise_if_empty { + let mut reader = std::io::Cursor::new(maybe_decompress_bytes(&memslice, owned)?); + if reader.read(&mut [0; 4])? < 2 && csv_options.raise_if_empty { polars_bail!(NoData: "empty CSV") } - curs.rewind()?; + reader.rewind()?; - let reader_bytes = get_reader_bytes(&mut curs).expect("could not mmap file"); + let reader_bytes = get_reader_bytes(&mut reader).expect("could not mmap file"); // this needs a way to estimated bytes/rows. 
- let si_result = - SchemaInferenceResult::try_from_reader_bytes_and_options(&reader_bytes, csv_options)?; - - Ok(si_result) + SchemaInferenceResult::try_from_reader_bytes_and_options(&reader_bytes, csv_options) }; let merge_func = |a: PolarsResult, - b: PolarsResult| match (a, b) { - (Err(e), _) | (_, Err(e)) => Err(e), - (Ok(a), Ok(b)) => { - let merged_schema = if csv_options.schema.is_some() { - csv_options.schema.clone().unwrap() - } else { - let schema_a = a.get_inferred_schema(); - let schema_b = b.get_inferred_schema(); - - match (schema_a.is_empty(), schema_b.is_empty()) { - (true, _) => schema_b, - (_, true) => schema_a, - _ => { - let mut s = Arc::unwrap_or_clone(schema_a); - s.to_supertype(&schema_b)?; - Arc::new(s) - }, - } - }; - - Ok(a.with_inferred_schema(merged_schema)) - }, + b: PolarsResult| { + match (a, b) { + (Err(e), _) | (_, Err(e)) => Err(e), + (Ok(a), Ok(b)) => { + let merged_schema = if csv_options.schema.is_some() { + csv_options.schema.clone().unwrap() + } else { + let schema_a = a.get_inferred_schema(); + let schema_b = b.get_inferred_schema(); + + match (schema_a.is_empty(), schema_b.is_empty()) { + (true, _) => schema_b, + (_, true) => schema_a, + _ => { + let mut s = Arc::unwrap_or_clone(schema_a); + s.to_supertype(&schema_b)?; + Arc::new(s) + }, + } + }; + + Ok(a.with_inferred_schema(merged_schema)) + }, + } }; let si_results = POOL.join( || infer_schema_func(0), || { - (1..paths.len()) + (1..sources.len()) .into_par_iter() .map(infer_schema_func) .reduce(|| Ok(Default::default()), merge_func) @@ -254,7 +238,7 @@ pub(super) fn csv_file_info( let reader_schema = if let Some(rc) = &file_options.row_index { let reader_schema = schema.clone(); let mut output_schema = (*reader_schema).clone(); - output_schema.insert_at_index(0, rc.name.as_ref().into(), IDX_DTYPE)?; + output_schema.insert_at_index(0, rc.name.clone(), IDX_DTYPE)?; schema = Arc::new(output_schema); reader_schema } else { @@ -272,58 +256,40 @@ pub(super) fn csv_file_info( #[cfg(feature = "json")] pub(super) fn ndjson_file_info( - paths: &[PathBuf], + sources: &ScanSources, file_options: &FileScanOptions, ndjson_options: &mut NDJsonReadOptions, cloud_options: Option<&polars_io::cloud::CloudOptions>, ) -> PolarsResult { use polars_core::config; + use polars_core::error::feature_gated; - let run_async = !paths.is_empty() && is_cloud_url(&paths[0]) || config::force_async(); + let Some(first) = sources.first() else { + polars_bail!(ComputeError: "expected at least 1 source"); + }; + + let run_async = sources.is_cloud_url() || (sources.is_paths() && config::force_async()); let cache_entries = { - #[cfg(feature = "cloud")] - { - if run_async { + if run_async { + feature_gated!("cloud", { Some(polars_io::file_cache::init_entries_from_uri_list( - paths + sources + .as_paths() + .unwrap() .iter() .map(|path| Arc::from(path.to_str().unwrap())) .collect::>() .as_slice(), cloud_options, )?) - } else { - None - } - } - #[cfg(not(feature = "cloud"))] - { - if run_async { - panic!("required feature `cloud` is not enabled") - } - } - }; - - let first_path = get_first_path(paths)?; - - let f = if run_async { - #[cfg(feature = "cloud")] - { - cache_entries.unwrap()[0].try_open_check_latest()? - } - #[cfg(not(feature = "cloud"))] - { - panic!("required feature `cloud` is not enabled") + }) + } else { + None } - } else { - polars_utils::open_file(first_path)? 
}; let owned = &mut vec![]; - let mmap = unsafe { memmap::Mmap::map(&f).unwrap() }; - - let mut reader = std::io::BufReader::new(maybe_decompress_bytes(mmap.as_ref(), owned)?); let (mut reader_schema, schema) = if let Some(schema) = ndjson_options.schema.take() { if file_options.row_index.is_none() { @@ -335,8 +301,12 @@ pub(super) fn ndjson_file_info( ) } } else { + let memslice = first.to_memslice_possibly_async(run_async, cache_entries.as_ref(), 0)?; + let mut reader = std::io::Cursor::new(maybe_decompress_bytes(&memslice, owned)?); + let schema = polars_io::ndjson::infer_schema(&mut reader, ndjson_options.infer_schema_length)?; + prepare_schemas(schema, file_options.row_index.as_ref()) }; diff --git a/crates/polars-plan/src/plans/conversion/stack_opt.rs b/crates/polars-plan/src/plans/conversion/stack_opt.rs index 6e05a872a8cf..8db4e82659d5 100644 --- a/crates/polars-plan/src/plans/conversion/stack_opt.rs +++ b/crates/polars-plan/src/plans/conversion/stack_opt.rs @@ -7,6 +7,12 @@ pub(super) struct ConversionOptimizer { scratch: Vec, simplify: Option, coerce: Option, + // IR's can be cached in the DSL. + // But if they are used multiple times in DSL (e.g. concat/join) + // then it can occur that we take a slot multiple times. + // So we keep track of the arena versions used and allow only + // one unique IR cache to be reused. + pub(super) used_arenas: PlHashSet, } impl ConversionOptimizer { @@ -27,6 +33,7 @@ impl ConversionOptimizer { scratch: Vec::with_capacity(8), simplify, coerce, + used_arenas: Default::default(), } } diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs b/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs index 4f8dd1ee0fb0..7ee2282b0da9 100644 --- a/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs +++ b/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs @@ -60,7 +60,7 @@ fn process_list_arithmetic( if type_right != **inner { let new_node_right = expr_arena.add(AExpr::Cast { expr: node_right, - data_type: *inner.clone(), + dtype: *inner.clone(), options: CastOptions::NonStrict, }); @@ -77,7 +77,7 @@ fn process_list_arithmetic( if type_left != **inner { let new_node_left = expr_arena.add(AExpr::Cast { expr: node_left, - data_type: *inner.clone(), + dtype: *inner.clone(), options: CastOptions::NonStrict, }); @@ -110,7 +110,7 @@ fn process_struct_numeric_arithmetic( if let Some(first) = fields.first() { let new_node_right = expr_arena.add(AExpr::Cast { expr: node_right, - data_type: DataType::Struct(vec![first.clone()]), + dtype: DataType::Struct(vec![first.clone()]), options: CastOptions::NonStrict, }); Ok(Some(AExpr::BinaryExpr { @@ -126,7 +126,7 @@ fn process_struct_numeric_arithmetic( if let Some(first) = fields.first() { let new_node_left = expr_arena.add(AExpr::Cast { expr: node_left, - data_type: DataType::Struct(vec![first.clone()]), + dtype: DataType::Struct(vec![first.clone()]), options: CastOptions::NonStrict, }); @@ -296,14 +296,20 @@ pub(super) fn process_binary( st = String } - // only cast if the type is not already the super type. + // TODO! raise here? + // We should at least never cast to Unknown. + if matches!(st, DataType::Unknown(UnknownKind::Any)) { + return Ok(None); + } + + // Only cast if the type is not already the super type. // this can prevent an expensive flattening and subsequent aggregation // in a group_by context. 
To be able to cast the groups need to be // flattened let new_node_left = if type_left != st { expr_arena.add(AExpr::Cast { expr: node_left, - data_type: st.clone(), + dtype: st.clone(), options: CastOptions::NonStrict, }) } else { @@ -312,7 +318,7 @@ pub(super) fn process_binary( let new_node_right = if type_right != st { expr_arena.add(AExpr::Cast { expr: node_right, - data_type: st, + dtype: st, options: CastOptions::NonStrict, }) } else { diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/functions.rs b/crates/polars-plan/src/plans/conversion/type_coercion/functions.rs new file mode 100644 index 000000000000..c7b738722c55 --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/type_coercion/functions.rs @@ -0,0 +1,83 @@ +use either::Either; + +use super::*; + +pub(super) fn get_function_dtypes( + input: &[ExprIR], + expr_arena: &Arena, + input_schema: &Schema, + function: &FunctionExpr, + mut options: FunctionOptions, +) -> PolarsResult, AExpr>> { + let mut early_return = move || { + // Next iteration this will not hit anymore as options is updated. + options.cast_to_supertypes = None; + Ok(Either::Right(AExpr::Function { + function: function.clone(), + input: input.to_vec(), + options, + })) + }; + + let mut dtypes = Vec::with_capacity(input.len()); + let mut first = true; + for e in input { + let Some((_, dtype)) = get_aexpr_and_type(expr_arena, e.node(), input_schema) else { + return early_return(); + }; + + if first { + check_namespace(function, &dtype)?; + first = false; + } + // Ignore Unknown in the inputs. + // We will raise if we cannot find the supertype later. + match dtype { + DataType::Unknown(UnknownKind::Any) => { + return early_return(); + }, + _ => dtypes.push(dtype), + } + } + + if dtypes.iter().all_equal() { + return early_return(); + } + Ok(Either::Left(dtypes)) +} + +// `str` namespace belongs to `String` +// `cat` namespace belongs to `Categorical` etc. 
+fn check_namespace(function: &FunctionExpr, first_dtype: &DataType) -> PolarsResult<()> { + match function { + #[cfg(feature = "strings")] + FunctionExpr::StringExpr(_) => { + polars_ensure!(first_dtype == &DataType::String, InvalidOperation: "expected String type, got: {}", first_dtype) + }, + FunctionExpr::BinaryExpr(_) => { + polars_ensure!(first_dtype == &DataType::Binary, InvalidOperation: "expected Binary type, got: {}", first_dtype) + }, + #[cfg(feature = "temporal")] + FunctionExpr::TemporalExpr(_) => { + polars_ensure!(first_dtype.is_temporal(), InvalidOperation: "expected Date(time)/Duration type, got: {}", first_dtype) + }, + FunctionExpr::ListExpr(_) => { + polars_ensure!(matches!(first_dtype, DataType::List(_)), InvalidOperation: "expected List type, got: {}", first_dtype) + }, + #[cfg(feature = "dtype-array")] + FunctionExpr::ArrayExpr(_) => { + polars_ensure!(matches!(first_dtype, DataType::Array(_, _)), InvalidOperation: "expected Array type, got: {}", first_dtype) + }, + #[cfg(feature = "dtype-struct")] + FunctionExpr::StructExpr(_) => { + polars_ensure!(matches!(first_dtype, DataType::Struct(_)), InvalidOperation: "expected Struct type, got: {}", first_dtype) + }, + #[cfg(feature = "dtype-categorical")] + FunctionExpr::Categorical(_) => { + polars_ensure!(matches!(first_dtype, DataType::Categorical(_, _)), InvalidOperation: "expected Categorical type, got: {}", first_dtype) + }, + _ => {}, + } + + Ok(()) +} diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/is_in.rs b/crates/polars-plan/src/plans/conversion/type_coercion/is_in.rs new file mode 100644 index 000000000000..34f54f6eb42e --- /dev/null +++ b/crates/polars-plan/src/plans/conversion/type_coercion/is_in.rs @@ -0,0 +1,97 @@ +use super::*; + +pub(super) fn resolve_is_in( + input: &[ExprIR], + expr_arena: &Arena, + lp_arena: &Arena, + lp_node: Node, +) -> PolarsResult> { + let input_schema = get_schema(lp_arena, lp_node); + let other_e = &input[1]; + let (_, type_left) = unpack!(get_aexpr_and_type( + expr_arena, + input[0].node(), + &input_schema + )); + let (_, type_other) = unpack!(get_aexpr_and_type( + expr_arena, + other_e.node(), + &input_schema + )); + + unpack!(early_escape(&type_left, &type_other)); + + let casted_expr = match (&type_left, &type_other) { + // types are equal, do nothing + (a, b) if a == b => return Ok(None), + // all-null can represent anything (and/or empty list), so cast to target dtype + (_, DataType::Null) => AExpr::Cast { + expr: other_e.node(), + dtype: type_left, + options: CastOptions::NonStrict, + }, + #[cfg(feature = "dtype-categorical")] + (DataType::Categorical(_, _) | DataType::Enum(_, _), DataType::String) => return Ok(None), + #[cfg(feature = "dtype-categorical")] + (DataType::String, DataType::Categorical(_, _) | DataType::Enum(_, _)) => return Ok(None), + #[cfg(feature = "dtype-decimal")] + (DataType::Decimal(_, _), dt) if dt.is_numeric() => AExpr::Cast { + expr: other_e.node(), + dtype: type_left, + options: CastOptions::NonStrict, + }, + #[cfg(feature = "dtype-decimal")] + (DataType::Decimal(_, _), _) | (_, DataType::Decimal(_, _)) => { + polars_bail!(InvalidOperation: "'is_in' cannot check for {:?} values in {:?} data", &type_other, &type_left) + }, + // can't check for more granular time_unit in less-granular time_unit data, + // or we'll cast away valid/necessary precision (eg: nanosecs to millisecs) + (DataType::Datetime(lhs_unit, _), DataType::Datetime(rhs_unit, _)) => { + if lhs_unit <= rhs_unit { + return Ok(None); + } else { +
polars_bail!(InvalidOperation: "'is_in' cannot check for {:?} precision values in {:?} Datetime data", &rhs_unit, &lhs_unit) + } + }, + (DataType::Duration(lhs_unit), DataType::Duration(rhs_unit)) => { + if lhs_unit <= rhs_unit { + return Ok(None); + } else { + polars_bail!(InvalidOperation: "'is_in' cannot check for {:?} precision values in {:?} Duration data", &rhs_unit, &lhs_unit) + } + }, + (_, DataType::List(other_inner)) => { + if other_inner.as_ref() == &type_left + || (type_left == DataType::Null) + || (other_inner.as_ref() == &DataType::Null) + || (other_inner.as_ref().is_numeric() && type_left.is_numeric()) + { + return Ok(None); + } + polars_bail!(InvalidOperation: "'is_in' cannot check for {:?} values in {:?} data", &type_left, &type_other) + }, + #[cfg(feature = "dtype-array")] + (_, DataType::Array(other_inner, _)) => { + if other_inner.as_ref() == &type_left + || (type_left == DataType::Null) + || (other_inner.as_ref() == &DataType::Null) + || (other_inner.as_ref().is_numeric() && type_left.is_numeric()) + { + return Ok(None); + } + polars_bail!(InvalidOperation: "'is_in' cannot check for {:?} values in {:?} data", &type_left, &type_other) + }, + #[cfg(feature = "dtype-struct")] + (DataType::Struct(_), _) | (_, DataType::Struct(_)) => return Ok(None), + + // don't attempt to cast between obviously mismatched types, but + // allow integer/float comparison (will use their supertypes). + (a, b) => { + if (a.is_numeric() && b.is_numeric()) || (a == &DataType::Null) { + return Ok(None); + } + polars_bail!(InvalidOperation: "'is_in' cannot check for {:?} values in {:?} data", &type_other, &type_left) + }, + }; + Ok(Some(casted_expr)) +} diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs index 41ec0d3cb483..ceb3d7dffd49 100644 --- a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs +++ b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs @@ -1,13 +1,18 @@ mod binary; +mod functions; +#[cfg(feature = "is_in")] +mod is_in; use std::borrow::Cow; -use arrow::legacy::utils::CustomIterTools; +use arrow::temporal_conversions::{time_unit_multiple, SECONDS_IN_DAY}; use binary::process_binary; +use either::Either; use polars_core::chunked_array::cast::CastOptions; use polars_core::prelude::*; use polars_core::utils::{get_supertype, get_supertype_with_options, materialize_dyn_int}; use polars_utils::idx_vec::UnitVec; +use polars_utils::itertools::Itertools; use polars_utils::{format_list, unitvec}; use super::*; @@ -22,6 +27,7 @@ macro_rules! unpack { } }; } +pub(super) use unpack; /// determine if we use the supertype or not. For instance when we have a column Int64 and we compare with literal UInt32 /// it would be wasteful to cast the column instead of the literal. 
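The doc comment above captures the intent behind the supertype decision: prefer casting the literal operand over casting the whole column. A minimal sketch of the observable behaviour, using the user-facing `polars` crate with the `lazy` and `fmt` features; the `df!`, `col`, and `lit` helpers below come from `polars::prelude` and are illustrative assumptions, not taken from this changeset:

use polars::prelude::*;

fn main() -> PolarsResult<()> {
    let df = df!("a" => &[1i64, 10, 100])?;

    // The UInt32 literal is the side that gets cast (to Int64, the column's dtype)
    // during type coercion; the Int64 column itself is left untouched, so no new
    // column has to be materialised just for the comparison.
    let out = df.lazy().filter(col("a").gt(lit(5u32))).collect()?;

    println!("{out}");
    Ok(())
}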
@@ -120,14 +126,14 @@ impl OptimizationRule for TypeCoercionRule { let out = match *expr { AExpr::Cast { expr, - ref data_type, + ref dtype, options, } => { let input = expr_arena.get(expr); inline_or_prune_cast( input, - data_type, + dtype, options.strict(), lp_node, lp_arena, @@ -158,7 +164,7 @@ impl OptimizationRule for TypeCoercionRule { let new_node_truthy = if type_true != st { expr_arena.add(AExpr::Cast { expr: truthy_node, - data_type: st.clone(), + dtype: st.clone(), options: CastOptions::Strict, }) } else { @@ -168,7 +174,7 @@ impl OptimizationRule for TypeCoercionRule { let new_node_falsy = if type_false != st { expr_arena.add(AExpr::Cast { expr: falsy_node, - data_type: st, + dtype: st, options: CastOptions::Strict, }) } else { @@ -192,98 +198,12 @@ impl OptimizationRule for TypeCoercionRule { ref input, options, } => { - let input_schema = get_schema(lp_arena, lp_node); - let other_e = &input[1]; - let (_, type_left) = unpack!(get_aexpr_and_type( - expr_arena, - input[0].node(), - &input_schema - )); - let (_, type_other) = unpack!(get_aexpr_and_type( - expr_arena, - other_e.node(), - &input_schema - )); - - unpack!(early_escape(&type_left, &type_other)); - - let casted_expr = match (&type_left, &type_other) { - // types are equal, do nothing - (a, b) if a == b => return Ok(None), - // all-null can represent anything (and/or empty list), so cast to target dtype - (_, DataType::Null) => AExpr::Cast { - expr: other_e.node(), - data_type: type_left, - options: CastOptions::NonStrict, - }, - #[cfg(feature = "dtype-categorical")] - (DataType::Categorical(_, _) | DataType::Enum(_, _), DataType::String) => { - return Ok(None) - }, - #[cfg(feature = "dtype-categorical")] - (DataType::String, DataType::Categorical(_, _) | DataType::Enum(_, _)) => { - return Ok(None) - }, - #[cfg(feature = "dtype-decimal")] - (DataType::Decimal(_, _), dt) if dt.is_numeric() => AExpr::Cast { - expr: other_e.node(), - data_type: type_left, - options: CastOptions::NonStrict, - }, - #[cfg(feature = "dtype-decimal")] - (DataType::Decimal(_, _), _) | (_, DataType::Decimal(_, _)) => { - polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} values in {:?} data", &type_other, &type_left) - }, - // can't check for more granular time_unit in less-granular time_unit data, - // or we'll cast away valid/necessary precision (eg: nanosecs to millisecs) - (DataType::Datetime(lhs_unit, _), DataType::Datetime(rhs_unit, _)) => { - if lhs_unit <= rhs_unit { - return Ok(None); - } else { - polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} precision values in {:?} Datetime data", &rhs_unit, &lhs_unit) - } - }, - (DataType::Duration(lhs_unit), DataType::Duration(rhs_unit)) => { - if lhs_unit <= rhs_unit { - return Ok(None); - } else { - polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} precision values in {:?} Duration data", &rhs_unit, &lhs_unit) - } - }, - (_, DataType::List(other_inner)) => { - if other_inner.as_ref() == &type_left - || (type_left == DataType::Null) - || (other_inner.as_ref() == &DataType::Null) - || (other_inner.as_ref().is_numeric() && type_left.is_numeric()) - { - return Ok(None); - } - polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} values in {:?} data", &type_left, &type_other) - }, - #[cfg(feature = "dtype-array")] - (_, DataType::Array(other_inner, _)) => { - if other_inner.as_ref() == &type_left - || (type_left == DataType::Null) - || (other_inner.as_ref() == &DataType::Null) - || (other_inner.as_ref().is_numeric() && type_left.is_numeric()) - { - return 
Ok(None); - } - polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} values in {:?} data", &type_left, &type_other) - }, - #[cfg(feature = "dtype-struct")] - (DataType::Struct(_), _) | (_, DataType::Struct(_)) => return Ok(None), - - // don't attempt to cast between obviously mismatched types, but - // allow integer/float comparison (will use their supertypes). - (a, b) => { - if (a.is_numeric() && b.is_numeric()) || (a == &DataType::Null) { - return Ok(None); - } - polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} values in {:?} data", &type_other, &type_left) - }, + let Some(casted_expr) = is_in::resolve_is_in(input, expr_arena, lp_arena, lp_node)? + else { + return Ok(None); }; - let mut input = input.clone(); + + let mut input = input.to_vec(); let other_input = expr_arena.add(casted_expr); input[1].set_node(other_input); @@ -299,8 +219,6 @@ impl OptimizationRule for TypeCoercionRule { ref input, options, } => { - let mut input = input.clone(); - let input_schema = get_schema(lp_arena, lp_node); let left_node = input[0].node(); let fill_value_node = input[2].node(); @@ -318,10 +236,11 @@ impl OptimizationRule for TypeCoercionRule { let super_type = modify_supertype(super_type, left, fill_value, &type_left, &type_fill_value); + let mut input = input.clone(); let new_node_left = if type_left != super_type { expr_arena.add(AExpr::Cast { expr: left_node, - data_type: super_type.clone(), + dtype: super_type.clone(), options: CastOptions::NonStrict, }) } else { @@ -331,7 +250,7 @@ impl OptimizationRule for TypeCoercionRule { let new_node_fill_value = if type_fill_value != super_type { expr_arena.add(AExpr::Cast { expr: fill_value_node, - data_type: super_type.clone(), + dtype: super_type.clone(), options: CastOptions::NonStrict, }) } else { @@ -355,25 +274,17 @@ impl OptimizationRule for TypeCoercionRule { mut options, } if options.cast_to_supertypes.is_some() => { let input_schema = get_schema(lp_arena, lp_node); - let mut dtypes = Vec::with_capacity(input.len()); - for e in input { - let (_, dtype) = - unpack!(get_aexpr_and_type(expr_arena, e.node(), &input_schema)); - // Ignore Unknown in the inputs. - // We will raise if we cannot find the supertype later. - match dtype { - DataType::Unknown(UnknownKind::Any) => { - options.cast_to_supertypes = None; - return Ok(None); - }, - _ => dtypes.push(dtype), - } - } - if dtypes.iter().all_equal() { - options.cast_to_supertypes = None; - return Ok(None); - } + let dtypes = match functions::get_function_dtypes( + input, + expr_arena, + &input_schema, + function, + options, + )? { + Either::Left(dtypes) => dtypes, + Either::Right(ae) => return Ok(Some(ae)), + }; // TODO! use args_to_supertype. 
let self_e = input[0].clone(); @@ -390,7 +301,8 @@ impl OptimizationRule for TypeCoercionRule { &type_other, options.cast_to_supertypes.unwrap(), ) else { - polars_bail!(InvalidOperation: "could not determine supertype of: {}", format_list!(dtypes)); + raise_supertype(function, input, &input_schema, expr_arena)?; + unreachable!() }; if input.len() == 2 { // modify_supertype is a bit more conservative of casting columns @@ -404,7 +316,8 @@ impl OptimizationRule for TypeCoercionRule { } if matches!(super_type, DataType::Unknown(UnknownKind::Any)) { - polars_bail!(InvalidOperation: "could not determine supertype of: {}", format_list!(dtypes)); + raise_supertype(function, input, &input_schema, expr_arena)?; + unreachable!() } let function = function.clone(); @@ -431,7 +344,7 @@ impl OptimizationRule for TypeCoercionRule { if dtype != super_type { let n = expr_arena.add(AExpr::Cast { expr: e.node(), - data_type: super_type.clone(), + dtype: super_type.clone(), options: CastOptions::NonStrict, }); e.set_node(n); @@ -454,10 +367,10 @@ impl OptimizationRule for TypeCoercionRule { let input_schema = get_schema(lp_arena, lp_node); let (_, offset_dtype) = unpack!(get_aexpr_and_type(expr_arena, offset, &input_schema)); - polars_ensure!(offset_dtype.is_integer(), InvalidOperation: "offset must be integral for slice, not {}", offset_dtype); + polars_ensure!(offset_dtype.is_integer(), InvalidOperation: "offset must be integral for slice expression, not {}", offset_dtype); let (_, length_dtype) = unpack!(get_aexpr_and_type(expr_arena, length, &input_schema)); - polars_ensure!(length_dtype.is_integer() || length_dtype.is_null(), InvalidOperation: "length must be integral for slice, not {}", length_dtype); + polars_ensure!(length_dtype.is_integer() || length_dtype.is_null(), InvalidOperation: "length must be integral for slice expression, not {}", length_dtype); None }, _ => None, @@ -508,6 +421,13 @@ fn inline_or_prune_cast( let av = AnyValue::String(s).strict_cast(dtype); return Ok(av.map(|av| AExpr::Literal(av.try_into().unwrap()))); }, + // We generate casted literal datetimes, so ensure we cast upon conversion + // to create simpler expr trees. + #[cfg(feature = "dtype-datetime")] + LiteralValue::DateTime(ts, tu, None) if dtype.is_date() => { + let from_size = time_unit_multiple(tu.to_arrow()) * SECONDS_IN_DAY; + LiteralValue::Date((*ts / from_size) as i32) + }, lv @ (LiteralValue::Int(_) | LiteralValue::Float(_)) => { let av = lv.to_any_value().ok_or_else(|| polars_err!(InvalidOperation: "literal value: {:?} too large for Polars", lv))?; let av = av.strict_cast(dtype); @@ -566,6 +486,37 @@ fn early_escape(type_self: &DataType, type_other: &DataType) -> Option<()> { } } +fn raise_supertype( + function: &FunctionExpr, + inputs: &[ExprIR], + input_schema: &Schema, + expr_arena: &Arena, +) -> PolarsResult<()> { + let dtypes = inputs + .iter() + .map(|e| { + let ae = expr_arena.get(e.node()); + ae.to_dtype(input_schema, Context::Default, expr_arena) + }) + .collect::>>()?; + + let st = dtypes + .iter() + .cloned() + .map(Some) + .reduce(|a, b| get_supertype(&a?, &b?)) + .expect("always at least 2 inputs"); + // We could get a supertype with the default options, so the input types are not allowed for this + // specific operation. 
+ if st.is_some() { + polars_bail!(InvalidOperation: "got invalid or ambiguous dtypes: '{}' in expression '{}'\ + \n\nConsider explicitly casting your input types to resolve potential ambiguity.", format_list!(&dtypes), function); + } else { + polars_bail!(InvalidOperation: "could not determine supertype of: {} in expression '{}'\ + \n\nIt might also be the case that the type combination isn't allowed in this specific operation.", format_list!(&dtypes), function); + } +} + #[cfg(test)] #[cfg(feature = "dtype-categorical")] mod test { @@ -581,7 +532,7 @@ mod test { let rules: &mut [Box] = &mut [Box::new(TypeCoercionRule {})]; let df = DataFrame::new(Vec::from([Series::new_empty( - "fruits", + PlSmallStr::from_static("fruits"), &DataType::Categorical(None, Default::default()), )])) .unwrap(); @@ -591,7 +542,8 @@ mod test { .project(expr_in.clone(), Default::default()) .build(); - let mut lp_top = to_alp(lp, &mut expr_arena, &mut lp_arena, true, true).unwrap(); + let mut lp_top = + to_alp(lp, &mut expr_arena, &mut lp_arena, &mut OptFlags::default()).unwrap(); lp_top = optimizer .optimize_loop(rules, &mut expr_arena, &mut lp_arena, lp_top) .unwrap(); @@ -606,7 +558,8 @@ mod test { let lp = DslBuilder::from_existing_df(df) .project(expr_in, Default::default()) .build(); - let mut lp_top = to_alp(lp, &mut expr_arena, &mut lp_arena, true, true).unwrap(); + let mut lp_top = + to_alp(lp, &mut expr_arena, &mut lp_arena, &mut OptFlags::default()).unwrap(); lp_top = optimizer .optimize_loop(rules, &mut expr_arena, &mut lp_arena, lp_top) .unwrap(); diff --git a/crates/polars-plan/src/plans/expr_ir.rs b/crates/polars-plan/src/plans/expr_ir.rs index 1161406a44b9..8512fdc8d8ea 100644 --- a/crates/polars-plan/src/plans/expr_ir.rs +++ b/crates/polars-plan/src/plans/expr_ir.rs @@ -3,27 +3,32 @@ use std::hash::Hash; #[cfg(feature = "cse")] use std::hash::Hasher; +use polars_utils::format_pl_smallstr; +#[cfg(feature = "ir_serde")] +use serde::{Deserialize, Serialize}; + use super::*; -use crate::constants::{get_len_name, LITERAL_NAME}; +use crate::constants::{get_len_name, get_literal_name}; #[derive(Default, Debug, Clone, Hash, PartialEq, Eq)] +#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] pub enum OutputName { /// No not yet set. #[default] None, /// The most left-hand-side literal will be the output name. - LiteralLhs(ColumnName), + LiteralLhs(PlSmallStr), /// The most left-hand-side column will be the output name. - ColumnLhs(ColumnName), - /// Rename the output as `ColumnName`. - Alias(ColumnName), + ColumnLhs(PlSmallStr), + /// Rename the output as `PlSmallStr`. + Alias(PlSmallStr), #[cfg(feature = "dtype-struct")] /// A struct field. - Field(ColumnName), + Field(PlSmallStr), } impl OutputName { - fn unwrap(&self) -> &ColumnName { + pub fn unwrap(&self) -> &PlSmallStr { match self { OutputName::Alias(name) => name, OutputName::ColumnLhs(name) => name, @@ -40,6 +45,7 @@ impl OutputName { } #[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] pub struct ExprIR { /// Output name of this expression. 
output_name: OutputName, @@ -74,9 +80,9 @@ impl ExprIR { }, AExpr::Literal(lv) => { if let LiteralValue::Series(s) = lv { - out.output_name = OutputName::LiteralLhs(s.name().into()); + out.output_name = OutputName::LiteralLhs(s.name().clone()); } else { - out.output_name = OutputName::LiteralLhs(LITERAL_NAME.into()); + out.output_name = OutputName::LiteralLhs(get_literal_name().clone()); } break; }, @@ -90,9 +96,8 @@ impl ExprIR { }, _ => { if input.is_empty() { - out.output_name = OutputName::LiteralLhs(ColumnName::from( - format!("{}", function), - )); + out.output_name = + OutputName::LiteralLhs(format_pl_smallstr!("{}", function)); } else { out.output_name = input[0].output_name.clone(); } @@ -102,7 +107,8 @@ impl ExprIR { }, AExpr::AnonymousFunction { input, options, .. } => { if input.is_empty() { - out.output_name = OutputName::LiteralLhs(ColumnName::from(options.fmt_str)); + out.output_name = + OutputName::LiteralLhs(PlSmallStr::from_static(options.fmt_str)); } else { out.output_name = input[0].output_name.clone(); } @@ -142,32 +148,28 @@ impl ExprIR { } #[cfg(feature = "cse")] - pub(crate) fn set_alias(&mut self, name: ColumnName) { + pub(crate) fn set_alias(&mut self, name: PlSmallStr) { self.output_name = OutputName::Alias(name) } - pub(crate) fn output_name_inner(&self) -> &OutputName { + pub fn output_name_inner(&self) -> &OutputName { &self.output_name } - pub(crate) fn output_name_arc(&self) -> &Arc { + pub fn output_name(&self) -> &PlSmallStr { self.output_name.unwrap() } - pub fn output_name(&self) -> &str { - self.output_name_arc().as_ref() - } - pub fn to_expr(&self, expr_arena: &Arena) -> Expr { let out = node_to_expr(self.node, expr_arena); match &self.output_name { - OutputName::Alias(name) => out.alias(name.as_ref()), + OutputName::Alias(name) => out.alias(name.clone()), _ => out, } } - pub fn get_alias(&self) -> Option<&ColumnName> { + pub fn get_alias(&self) -> Option<&PlSmallStr> { match &self.output_name { OutputName::Alias(name) => Some(name), _ => None, @@ -175,7 +177,7 @@ impl ExprIR { } /// Gets any name except one deriving from `Column`. 
- pub(crate) fn get_non_projected_name(&self) -> Option<&ColumnName> { + pub(crate) fn get_non_projected_name(&self) -> Option<&PlSmallStr> { match &self.output_name { OutputName::Alias(name) => Some(name), #[cfg(feature = "dtype-struct")] @@ -227,20 +229,20 @@ impl From<&ExprIR> for Node { } } -pub(crate) fn name_to_expr_ir(name: &str, expr_arena: &mut Arena) -> ExprIR { - let name = ColumnName::from(name); +pub(crate) fn name_to_expr_ir(name: PlSmallStr, expr_arena: &mut Arena) -> ExprIR { let node = expr_arena.add(AExpr::Column(name.clone())); ExprIR::new(node, OutputName::ColumnLhs(name)) } -pub(crate) fn names_to_expr_irs, S: AsRef>( - names: I, - expr_arena: &mut Arena, -) -> Vec { +pub(crate) fn names_to_expr_irs(names: I, expr_arena: &mut Arena) -> Vec +where + I: IntoIterator, + S: Into, +{ names .into_iter() .map(|name| { - let name = name.as_ref(); + let name = name.into(); name_to_expr_ir(name, expr_arena) }) .collect() diff --git a/crates/polars-plan/src/plans/format.rs b/crates/polars-plan/src/plans/format.rs index 980fbc883756..d39f3dd35cc9 100644 --- a/crates/polars-plan/src/plans/format.rs +++ b/crates/polars-plan/src/plans/format.rs @@ -124,13 +124,13 @@ impl fmt::Debug for Expr { }, Cast { expr, - data_type, + dtype, options, } => { if options.strict() { - write!(f, "{expr:?}.strict_cast({data_type:?})") + write!(f, "{expr:?}.strict_cast({dtype:?})") } else { - write!(f, "{expr:?}.cast({data_type:?})") + write!(f, "{expr:?}.cast({dtype:?})") } }, Ternary { diff --git a/crates/polars-plan/src/plans/functions/count.rs b/crates/polars-plan/src/plans/functions/count.rs index fd92fdd9fc9d..7375ff47ff31 100644 --- a/crates/polars-plan/src/plans/functions/count.rs +++ b/crates/polars-plan/src/plans/functions/count.rs @@ -1,22 +1,35 @@ #[cfg(feature = "ipc")] use arrow::io::ipc::read::get_row_count as count_rows_ipc_sync; -#[cfg(feature = "parquet")] +#[cfg(any( + feature = "parquet", + feature = "ipc", + feature = "json", + feature = "csv" +))] +use polars_core::error::feature_gated; +#[cfg(any(feature = "parquet", feature = "json"))] use polars_io::cloud::CloudOptions; #[cfg(feature = "csv")] -use polars_io::csv::read::count_rows as count_rows_csv; +use polars_io::csv::read::{ + count_rows as count_rows_csv, count_rows_from_slice as count_rows_csv_from_slice, +}; #[cfg(all(feature = "parquet", feature = "cloud"))] use polars_io::parquet::read::ParquetAsyncReader; #[cfg(feature = "parquet")] use polars_io::parquet::read::ParquetReader; #[cfg(all(feature = "parquet", feature = "async"))] use polars_io::pl_async::{get_runtime, with_concurrency_budget}; -#[cfg(any(feature = "parquet", feature = "ipc"))] -use polars_io::{path_utils::is_cloud_url, SerReader}; +#[cfg(any(feature = "json", feature = "parquet"))] +use polars_io::SerReader; use super::*; #[allow(unused_variables)] -pub fn count_rows(paths: &Arc>, scan_type: &FileScan) -> PolarsResult { +pub fn count_rows( + sources: &ScanSources, + scan_type: &FileScan, + alias: Option, +) -> PolarsResult { #[cfg(not(any( feature = "parquet", feature = "ipc", @@ -39,26 +52,10 @@ pub fn count_rows(paths: &Arc>, scan_type: &FileScan) -> PolarsResu FileScan::Csv { options, cloud_options, - } => { - let parse_options = options.get_parse_options(); - let n_rows: PolarsResult = paths - .iter() - .map(|path| { - count_rows_csv( - path, - parse_options.separator, - parse_options.quote_char, - parse_options.comment_prefix.as_ref(), - parse_options.eol_char, - options.has_header, - ) - }) - .sum(); - n_rows - }, + } => count_all_rows_csv(sources, 
options), #[cfg(feature = "parquet")] FileScan::Parquet { cloud_options, .. } => { - count_rows_parquet(paths, cloud_options.as_ref()) + count_rows_parquet(sources, cloud_options.as_ref()) }, #[cfg(feature = "ipc")] FileScan::Ipc { @@ -66,7 +63,7 @@ pub fn count_rows(paths: &Arc>, scan_type: &FileScan) -> PolarsResu cloud_options, metadata, } => count_rows_ipc( - paths, + sources, #[cfg(feature = "cloud")] cloud_options.as_ref(), metadata.as_ref(), @@ -75,7 +72,7 @@ pub fn count_rows(paths: &Arc>, scan_type: &FileScan) -> PolarsResu FileScan::NDJson { options, cloud_options, - } => count_rows_ndjson(paths, cloud_options.as_ref()), + } => count_rows_ndjson(sources, cloud_options.as_ref()), FileScan::Anonymous { .. } => { unreachable!() }, @@ -84,34 +81,67 @@ pub fn count_rows(paths: &Arc>, scan_type: &FileScan) -> PolarsResu let count: IdxSize = count.try_into().map_err( |_| polars_err!(ComputeError: "count of {} exceeded maximum row size", count), )?; - DataFrame::new(vec![Series::new(crate::constants::LEN, [count])]) + let column_name = alias.unwrap_or(PlSmallStr::from_static(crate::constants::LEN)); + DataFrame::new(vec![Series::new(column_name, [count])]) } } + +#[cfg(feature = "csv")] +fn count_all_rows_csv( + sources: &ScanSources, + options: &polars_io::prelude::CsvReadOptions, +) -> PolarsResult { + let parse_options = options.get_parse_options(); + + sources + .iter() + .map(|source| match source { + ScanSourceRef::Path(path) => count_rows_csv( + path, + parse_options.separator, + parse_options.quote_char, + parse_options.comment_prefix.as_ref(), + parse_options.eol_char, + options.has_header, + ), + _ => { + let memslice = source.to_memslice()?; + + count_rows_csv_from_slice( + &memslice[..], + parse_options.separator, + parse_options.quote_char, + parse_options.comment_prefix.as_ref(), + parse_options.eol_char, + options.has_header, + ) + }, + }) + .sum() +} + #[cfg(feature = "parquet")] pub(super) fn count_rows_parquet( - paths: &Arc>, - cloud_options: Option<&CloudOptions>, + sources: &ScanSources, + #[allow(unused)] cloud_options: Option<&CloudOptions>, ) -> PolarsResult { - if paths.is_empty() { + if sources.is_empty() { return Ok(0); }; - let is_cloud = is_cloud_url(paths.first().unwrap().as_path()); + let is_cloud = sources.is_cloud_url(); if is_cloud { - #[cfg(not(feature = "cloud"))] - panic!("One or more of the cloud storage features ('aws', 'gcp', ...) 
must be enabled."); - - #[cfg(feature = "cloud")] - { - get_runtime().block_on(count_rows_cloud_parquet(paths, cloud_options)) - } + feature_gated!("cloud", { + get_runtime().block_on(count_rows_cloud_parquet( + sources.as_paths().unwrap(), + cloud_options, + )) + }) } else { - paths + sources .iter() - .map(|path| { - let file = polars_utils::open_file(path)?; - let mut reader = ParquetReader::new(file); - reader.num_rows() + .map(|source| { + ParquetReader::new(std::io::Cursor::new(source.to_memslice()?)).num_rows() }) .sum::>() } @@ -119,7 +149,7 @@ pub(super) fn count_rows_parquet( #[cfg(all(feature = "parquet", feature = "async"))] async fn count_rows_cloud_parquet( - paths: &Arc>, + paths: &[std::path::PathBuf], cloud_options: Option<&CloudOptions>, ) -> PolarsResult { let collection = paths.iter().map(|path| { @@ -136,37 +166,37 @@ async fn count_rows_cloud_parquet( #[cfg(feature = "ipc")] pub(super) fn count_rows_ipc( - paths: &Arc>, + sources: &ScanSources, #[cfg(feature = "cloud")] cloud_options: Option<&CloudOptions>, metadata: Option<&arrow::io::ipc::read::FileMetadata>, ) -> PolarsResult { - if paths.is_empty() { + if sources.is_empty() { return Ok(0); }; - let is_cloud = is_cloud_url(paths.first().unwrap().as_path()); + let is_cloud = sources.is_cloud_url(); if is_cloud { - #[cfg(not(feature = "cloud"))] - panic!("One or more of the cloud storage features ('aws', 'gcp', ...) must be enabled."); - - #[cfg(feature = "cloud")] - { - get_runtime().block_on(count_rows_cloud_ipc(paths, cloud_options, metadata)) - } + feature_gated!("cloud", { + get_runtime().block_on(count_rows_cloud_ipc( + sources.as_paths().unwrap(), + cloud_options, + metadata, + )) + }) } else { - paths + sources .iter() - .map(|path| { - let mut reader = polars_utils::open_file(path)?; - count_rows_ipc_sync(&mut reader).map(|v| v as usize) + .map(|source| { + let memslice = source.to_memslice()?; + count_rows_ipc_sync(&mut std::io::Cursor::new(memslice)).map(|v| v as usize) }) - .sum() + .sum::>() } } #[cfg(all(feature = "ipc", feature = "async"))] async fn count_rows_cloud_ipc( - paths: &Arc>, + paths: &[std::path::PathBuf], cloud_options: Option<&CloudOptions>, metadata: Option<&arrow::io::ipc::read::FileMetadata>, ) -> PolarsResult { @@ -185,55 +215,48 @@ async fn count_rows_cloud_ipc( #[cfg(feature = "json")] pub(super) fn count_rows_ndjson( - paths: &Arc>, + sources: &ScanSources, cloud_options: Option<&CloudOptions>, ) -> PolarsResult { use polars_core::config; + use polars_io::utils::maybe_decompress_bytes; - let run_async = !paths.is_empty() && is_cloud_url(&paths[0]) || config::force_async(); + if sources.is_empty() { + return Ok(0); + } + + let is_cloud_url = sources.is_cloud_url(); + let run_async = is_cloud_url || (sources.is_paths() && config::force_async()); let cache_entries = { - #[cfg(feature = "cloud")] - { - if run_async { + if run_async { + feature_gated!("cloud", { Some(polars_io::file_cache::init_entries_from_uri_list( - paths + sources + .as_paths() + .unwrap() .iter() .map(|path| Arc::from(path.to_str().unwrap())) .collect::>() .as_slice(), cloud_options, )?) - } else { - None - } - } - #[cfg(not(feature = "cloud"))] - { - if run_async { - panic!("required feature `cloud` is not enabled") - } + }) + } else { + None } }; - (0..paths.len()) - .map(|i| { - let f = if run_async { - #[cfg(feature = "cloud")] - { - let entry: &Arc = - &cache_entries.as_ref().unwrap()[0]; - entry.try_open_check_latest()? 
- } - #[cfg(not(feature = "cloud"))] - { - panic!("required feature `cloud` is not enabled") - } - } else { - polars_utils::open_file(&paths[i])? - }; - - let reader = polars_io::ndjson::core::JsonLineReader::new(f); + sources + .iter() + .map(|source| { + let memslice = + source.to_memslice_possibly_async(run_async, cache_entries.as_ref(), 0)?; + + let owned = &mut vec![]; + let reader = polars_io::ndjson::core::JsonLineReader::new(std::io::Cursor::new( + maybe_decompress_bytes(&memslice[..], owned)?, + )); reader.count() }) .sum() diff --git a/crates/polars-plan/src/plans/functions/dsl.rs b/crates/polars-plan/src/plans/functions/dsl.rs index 80314787e83e..fd4b740af9df 100644 --- a/crates/polars-plan/src/plans/functions/dsl.rs +++ b/crates/polars-plan/src/plans/functions/dsl.rs @@ -1,25 +1,49 @@ +use strum_macros::IntoStaticStr; + use super::*; -use crate::plans::conversion::rewrite_projections; -// Except for Opaque functions, this only has the DSL name of the function. +#[cfg(feature = "python")] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Clone)] +pub struct OpaquePythonUdf { + pub function: PythonFunction, + pub schema: Option, + /// allow predicate pushdown optimizations + pub predicate_pd: bool, + /// allow projection pushdown optimizations + pub projection_pd: bool, + pub streamable: bool, + pub validate_output: bool, +} + +// Except for Opaque functions, this only has the DSL name of the function. +#[derive(Clone, IntoStaticStr)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "SCREAMING_SNAKE_CASE")] pub enum DslFunction { - FunctionNode(FunctionNode), + // Function that is already converted to IR. + #[cfg_attr(feature = "serde", serde(skip))] + FunctionIR(FunctionIR), + // This is both in DSL and IR because we want to be able to serialize it. 
+ #[cfg(feature = "python")] + OpaquePython(OpaquePythonUdf), Explode { - columns: Vec, + columns: Vec, + allow_empty: bool, }, + #[cfg(feature = "pivot")] Unpivot { - args: UnpivotArgs, + args: UnpivotArgsDSL, }, RowIndex { - name: Arc, + name: PlSmallStr, offset: Option, }, Rename { - existing: Arc<[SmartString]>, - new: Arc<[SmartString]>, + existing: Arc<[PlSmallStr]>, + new: Arc<[PlSmallStr]>, }, + Unnest(Vec), Stats(StatsFunction), /// FillValue FillNan(Expr), @@ -56,53 +80,67 @@ pub enum StatsFunction { Max, } +pub(crate) fn validate_columns_in_input>( + columns: &[S], + input_schema: &Schema, + operation_name: &str, +) -> PolarsResult<()> { + for c in columns { + polars_ensure!(input_schema.contains(c.as_ref()), ColumnNotFound: "'{}' on column: '{}' is invalid\n\nSchema at this point: {:?}", operation_name, c.as_ref(), input_schema) + } + Ok(()) +} + impl DslFunction { - pub(crate) fn into_function_node(self, input_schema: &Schema) -> PolarsResult { + pub(crate) fn into_function_ir(self, input_schema: &Schema) -> PolarsResult { let function = match self { - DslFunction::Explode { columns } => { - let columns = rewrite_projections(columns, input_schema, &[])?; - // columns to string - let columns = columns - .iter() - .map(|e| { - let Expr::Column(name) = e else { - polars_bail!(InvalidOperation: "expected column expression") - }; - polars_ensure!(input_schema.contains(name), col_not_found = name); - Ok(name.clone()) - }) - .collect::]>>>()?; - FunctionNode::Explode { - columns, + #[cfg(feature = "pivot")] + DslFunction::Unpivot { args } => { + let on = expand_selectors(args.on, input_schema, &[])?; + let index = expand_selectors(args.index, input_schema, &[])?; + validate_columns_in_input(on.as_ref(), input_schema, "unpivot")?; + validate_columns_in_input(index.as_ref(), input_schema, "unpivot")?; + + let args = UnpivotArgsIR { + on: on.iter().cloned().collect(), + index: index.iter().cloned().collect(), + variable_name: args.variable_name.clone(), + value_name: args.value_name.clone(), + }; + + FunctionIR::Unpivot { + args: Arc::new(args), schema: Default::default(), } }, - DslFunction::Unpivot { args } => FunctionNode::Unpivot { - args: Arc::new(args), - schema: Default::default(), - }, - DslFunction::FunctionNode(func) => func, - DslFunction::RowIndex { name, offset } => FunctionNode::RowIndex { + DslFunction::FunctionIR(func) => func, + DslFunction::RowIndex { name, offset } => FunctionIR::RowIndex { name, offset, schema: Default::default(), }, DslFunction::Rename { existing, new } => { let swapping = new.iter().any(|name| input_schema.get(name).is_some()); + validate_columns_in_input(existing.as_ref(), input_schema, "rename")?; - // Check if the name exists. - for name in existing.iter() { - let _ = input_schema.try_get(name)?; - } - - FunctionNode::Rename { + FunctionIR::Rename { existing, new, swapping, schema: Default::default(), } }, - DslFunction::Stats(_) | DslFunction::FillNan(_) | DslFunction::Drop(_) => { + DslFunction::Unnest(selectors) => { + let columns = expand_selectors(selectors, input_schema, &[])?; + validate_columns_in_input(columns.as_ref(), input_schema, "explode")?; + FunctionIR::Unnest { columns } + }, + #[cfg(feature = "python")] + DslFunction::OpaquePython(inner) => FunctionIR::OpaquePython(inner), + DslFunction::Stats(_) + | DslFunction::FillNan(_) + | DslFunction::Drop(_) + | DslFunction::Explode { .. } => { // We should not reach this. 
panic!("impl error") }, @@ -121,20 +159,17 @@ impl Display for DslFunction { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { use DslFunction::*; match self { - FunctionNode(inner) => write!(f, "{inner}"), - Explode { .. } => write!(f, "EXPLODE"), - Unpivot { .. } => write!(f, "UNPIVOT"), - RowIndex { .. } => write!(f, "WITH ROW INDEX"), - Stats(_) => write!(f, "STATS"), - FillNan(_) => write!(f, "FILL NAN"), - Drop(_) => write!(f, "DROP"), - Rename { .. } => write!(f, "RENAME"), + FunctionIR(inner) => write!(f, "{inner}"), + v => { + let s: &str = v.into(); + write!(f, "{s}") + }, } } } -impl From for DslFunction { - fn from(value: FunctionNode) -> Self { - DslFunction::FunctionNode(value) +impl From for DslFunction { + fn from(value: FunctionIR) -> Self { + DslFunction::FunctionIR(value) } } diff --git a/crates/polars-plan/src/plans/functions/explode.rs b/crates/polars-plan/src/plans/functions/explode.rs index 0103ed5f2818..a5140d81103b 100644 --- a/crates/polars-plan/src/plans/functions/explode.rs +++ b/crates/polars-plan/src/plans/functions/explode.rs @@ -1,5 +1,5 @@ use super::*; -pub(super) fn explode_impl(df: DataFrame, columns: &[SmartString]) -> PolarsResult { +pub(super) fn explode_impl(df: DataFrame, columns: &[PlSmallStr]) -> PolarsResult { df.explode(columns) } diff --git a/crates/polars-plan/src/plans/functions/merge_sorted.rs b/crates/polars-plan/src/plans/functions/merge_sorted.rs index a20a85d68812..ffc9e1f04df6 100644 --- a/crates/polars-plan/src/plans/functions/merge_sorted.rs +++ b/crates/polars-plan/src/plans/functions/merge_sorted.rs @@ -10,7 +10,7 @@ pub(super) fn merge_sorted(df: &DataFrame, column: &str) -> PolarsResult PolarsResult, - /// allow predicate pushdown optimizations - predicate_pd: bool, - /// allow projection pushdown optimizations - projection_pd: bool, - streamable: bool, - validate_output: bool, - }, - #[cfg_attr(feature = "serde", serde(skip))] + OpaquePython(OpaquePythonUdf), + #[cfg_attr(feature = "ir_serde", serde(skip))] Opaque { function: Arc, schema: Option>, @@ -49,23 +43,22 @@ pub enum FunctionNode { projection_pd: bool, streamable: bool, // used for formatting - #[cfg_attr(feature = "serde", serde(skip))] - fmt_str: &'static str, + fmt_str: PlSmallStr, }, - Count { - paths: Arc>, + FastCount { + sources: ScanSources, scan_type: FileScan, - alias: Option>, + alias: Option, }, - #[cfg_attr(feature = "serde", serde(skip))] /// Streaming engine pipeline + #[cfg_attr(feature = "ir_serde", serde(skip))] Pipeline { function: Arc>, schema: SchemaRef, original: Option>, }, Unnest { - columns: Arc<[Arc]>, + columns: Arc<[PlSmallStr]>, }, Rechunk, // The two DataFrames are temporary concatenated @@ -75,43 +68,51 @@ pub enum FunctionNode { #[cfg(feature = "merge_sorted")] MergeSorted { // sorted column that serves as the key - column: Arc, + column: PlSmallStr, }, Rename { - existing: Arc<[SmartString]>, - new: Arc<[SmartString]>, + existing: Arc<[PlSmallStr]>, + new: Arc<[PlSmallStr]>, // A column name gets swapped with an existing column swapping: bool, - #[cfg_attr(feature = "serde", serde(skip))] + #[cfg_attr(feature = "ir_serde", serde(skip))] schema: CachedSchema, }, Explode { - columns: Arc<[Arc]>, - #[cfg_attr(feature = "serde", serde(skip))] + columns: Arc<[PlSmallStr]>, + #[cfg_attr(feature = "ir_serde", serde(skip))] schema: CachedSchema, }, + #[cfg(feature = "pivot")] Unpivot { - args: Arc, - #[cfg_attr(feature = "serde", serde(skip))] + args: Arc, + #[cfg_attr(feature = "ir_serde", serde(skip))] schema: CachedSchema, }, RowIndex 
{ - name: Arc, + name: PlSmallStr, // Might be cached. - #[cfg_attr(feature = "serde", serde(skip))] + #[cfg_attr(feature = "ir_serde", serde(skip))] schema: CachedSchema, offset: Option, }, } -impl Eq for FunctionNode {} +impl Eq for FunctionIR {} -impl PartialEq for FunctionNode { +impl PartialEq for FunctionIR { fn eq(&self, other: &Self) -> bool { - use FunctionNode::*; + use FunctionIR::*; match (self, other) { (Rechunk, Rechunk) => true, - (Count { paths: paths_l, .. }, Count { paths: paths_r, .. }) => paths_l == paths_r, + ( + FastCount { + sources: srcs_l, .. + }, + FastCount { + sources: srcs_r, .. + }, + ) => srcs_l == srcs_r, ( Rename { existing: existing_l, @@ -125,6 +126,7 @@ impl PartialEq for FunctionNode { }, ) => existing_l == existing_r && new_l == new_r, (Explode { columns: l, .. }, Explode { columns: r, .. }) => l == r, + #[cfg(feature = "pivot")] (Unpivot { args: l, .. }, Unpivot { args: r, .. }) => l == r, (RowIndex { name: l, .. }, RowIndex { name: r, .. }) => l == r, #[cfg(feature = "merge_sorted")] @@ -134,28 +136,28 @@ impl PartialEq for FunctionNode { } } -impl Hash for FunctionNode { +impl Hash for FunctionIR { fn hash(&self, state: &mut H) { std::mem::discriminant(self).hash(state); match self { #[cfg(feature = "python")] - FunctionNode::OpaquePython { .. } => {}, - FunctionNode::Opaque { fmt_str, .. } => fmt_str.hash(state), - FunctionNode::Count { - paths, + FunctionIR::OpaquePython { .. } => {}, + FunctionIR::Opaque { fmt_str, .. } => fmt_str.hash(state), + FunctionIR::FastCount { + sources, scan_type, alias, } => { - paths.hash(state); + sources.hash(state); scan_type.hash(state); alias.hash(state); }, - FunctionNode::Pipeline { .. } => {}, - FunctionNode::Unnest { columns } => columns.hash(state), - FunctionNode::Rechunk => {}, + FunctionIR::Pipeline { .. } => {}, + FunctionIR::Unnest { columns } => columns.hash(state), + FunctionIR::Rechunk => {}, #[cfg(feature = "merge_sorted")] - FunctionNode::MergeSorted { column } => column.hash(state), - FunctionNode::Rename { + FunctionIR::MergeSorted { column } => column.hash(state), + FunctionIR::Rename { existing, new, swapping: _, @@ -164,9 +166,10 @@ impl Hash for FunctionNode { existing.hash(state); new.hash(state); }, - FunctionNode::Explode { columns, schema: _ } => columns.hash(state), - FunctionNode::Unpivot { args, schema: _ } => args.hash(state), - FunctionNode::RowIndex { + FunctionIR::Explode { columns, schema: _ } => columns.hash(state), + #[cfg(feature = "pivot")] + FunctionIR::Unpivot { args, schema: _ } => args.hash(state), + FunctionIR::RowIndex { name, schema: _, offset, @@ -178,60 +181,62 @@ impl Hash for FunctionNode { } } -impl FunctionNode { +impl FunctionIR { /// Whether this function can run on batches of data at a time. pub fn is_streamable(&self) -> bool { - use FunctionNode::*; + use FunctionIR::*; match self { Rechunk | Pipeline { .. } => false, #[cfg(feature = "merge_sorted")] MergeSorted { .. } => false, - Count { .. } | Unnest { .. } | Rename { .. } | Explode { .. } => true, - Unpivot { args, .. } => args.streamable, + FastCount { .. } | Unnest { .. } | Rename { .. } | Explode { .. } => true, + #[cfg(feature = "pivot")] + Unpivot { .. } => true, Opaque { streamable, .. } => *streamable, #[cfg(feature = "python")] - OpaquePython { streamable, .. } => *streamable, + OpaquePython(OpaquePythonUdf { streamable, .. }) => *streamable, RowIndex { .. 
} => false, } } /// Whether this function will increase the number of rows pub fn expands_rows(&self) -> bool { - use FunctionNode::*; + use FunctionIR::*; match self { #[cfg(feature = "merge_sorted")] MergeSorted { .. } => true, - Explode { .. } | Unpivot { .. } => true, + #[cfg(feature = "pivot")] + Unpivot { .. } => true, + Explode { .. } => true, _ => false, } } pub(crate) fn allow_predicate_pd(&self) -> bool { - use FunctionNode::*; + use FunctionIR::*; match self { Opaque { predicate_pd, .. } => *predicate_pd, #[cfg(feature = "python")] - OpaquePython { predicate_pd, .. } => *predicate_pd, - Rechunk | Unnest { .. } | Rename { .. } | Explode { .. } | Unpivot { .. } => true, + OpaquePython(OpaquePythonUdf { predicate_pd, .. }) => *predicate_pd, + #[cfg(feature = "pivot")] + Unpivot { .. } => true, + Rechunk | Unnest { .. } | Rename { .. } | Explode { .. } => true, #[cfg(feature = "merge_sorted")] MergeSorted { .. } => true, - RowIndex { .. } | Count { .. } => false, + RowIndex { .. } | FastCount { .. } => false, Pipeline { .. } => unimplemented!(), } } pub(crate) fn allow_projection_pd(&self) -> bool { - use FunctionNode::*; + use FunctionIR::*; match self { Opaque { projection_pd, .. } => *projection_pd, #[cfg(feature = "python")] - OpaquePython { projection_pd, .. } => *projection_pd, - Rechunk - | Count { .. } - | Unnest { .. } - | Rename { .. } - | Explode { .. } - | Unpivot { .. } => true, + OpaquePython(OpaquePythonUdf { projection_pd, .. }) => *projection_pd, + Rechunk | FastCount { .. } | Unnest { .. } | Rename { .. } | Explode { .. } => true, + #[cfg(feature = "pivot")] + Unpivot { .. } => true, #[cfg(feature = "merge_sorted")] MergeSorted { .. } => true, RowIndex { .. } => true, @@ -239,8 +244,8 @@ impl FunctionNode { } } - pub(crate) fn additional_projection_pd_columns(&self) -> Cow<[Arc]> { - use FunctionNode::*; + pub(crate) fn additional_projection_pd_columns(&self) -> Cow<[PlSmallStr]> { + use FunctionIR::*; match self { Unnest { columns } => Cow::Borrowed(columns.as_ref()), Explode { columns, .. } => Cow::Borrowed(columns.as_ref()), @@ -251,19 +256,21 @@ impl FunctionNode { } pub fn evaluate(&self, mut df: DataFrame) -> PolarsResult { - use FunctionNode::*; + use FunctionIR::*; match self { Opaque { function, .. } => function.call_udf(df), #[cfg(feature = "python")] - OpaquePython { + OpaquePython(OpaquePythonUdf { function, validate_output, schema, .. - } => python_udf::call_python_udf(function, df, *validate_output, schema.as_deref()), - Count { - paths, scan_type, .. - } => count::count_rows(paths, scan_type), + }) => python_udf::call_python_udf(function, df, *validate_output, schema.as_deref()), + FastCount { + sources, + scan_type, + alias, + } => count::count_rows(sources, scan_type, alias.clone()), Rechunk => { df.as_single_chunk_par(); Ok(df) @@ -271,14 +278,7 @@ impl FunctionNode { #[cfg(feature = "merge_sorted")] MergeSorted { column } => merge_sorted(&df, column.as_ref()), Unnest { columns: _columns } => { - #[cfg(feature = "dtype-struct")] - { - df.unnest(_columns.as_ref()) - } - #[cfg(not(feature = "dtype-struct"))] - { - panic!("activate feature 'dtype-struct'") - } + feature_gated!("dtype-struct", df.unnest(_columns.iter().cloned())) }, Pipeline { function, .. } => { // we use a global string cache here as streaming chunks all have different rev maps @@ -294,12 +294,14 @@ impl FunctionNode { } }, Rename { existing, new, .. } => rename::rename_impl(df, existing, new), - Explode { columns, .. } => df.explode(columns.as_ref()), + Explode { columns, .. 
} => df.explode(columns.iter().cloned()), + #[cfg(feature = "pivot")] Unpivot { args, .. } => { + use polars_ops::pivot::UnpivotDF; let args = (**args).clone(); df.unpivot2(args) }, - RowIndex { name, offset, .. } => df.with_row_index(name.as_ref(), *offset), + RowIndex { name, offset, .. } => df.with_row_index(name.clone(), *offset), } } @@ -317,28 +319,22 @@ impl FunctionNode { } } -impl Debug for FunctionNode { +impl Debug for FunctionIR { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{self}") } } -impl Display for FunctionNode { +impl Display for FunctionIR { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - use FunctionNode::*; + use FunctionIR::*; match self { Opaque { fmt_str, .. } => write!(f, "{fmt_str}"), - #[cfg(feature = "python")] - OpaquePython { .. } => write!(f, "python dataframe udf"), - Rechunk => write!(f, "RECHUNK"), - Count { .. } => write!(f, "FAST COUNT(*)"), Unnest { columns } => { write!(f, "UNNEST by:")?; let columns = columns.as_ref(); fmt_column_delimited(f, columns, "[", "]") }, - #[cfg(feature = "merge_sorted")] - MergeSorted { .. } => write!(f, "MERGE SORTED"), Pipeline { original, .. } => { if let Some(original) = original { let ir_display = original.as_ref().display(); @@ -351,10 +347,25 @@ impl Display for FunctionNode { write!(f, "STREAMING") } }, - Rename { .. } => write!(f, "RENAME"), - Explode { .. } => write!(f, "EXPLODE"), - Unpivot { .. } => write!(f, "UNPIVOT"), - RowIndex { .. } => write!(f, "WITH ROW INDEX"), + FastCount { + sources, + scan_type, + alias, + } => { + let scan_type: &str = scan_type.into(); + let default_column_name = PlSmallStr::from_static(crate::constants::LEN); + let alias = alias.as_ref().unwrap_or(&default_column_name); + + write!( + f, + "FAST COUNT ({scan_type}) {} as \"{alias}\"", + ScanSourcesDisplay(sources) + ) + }, + v => { + let s: &str = v.into(); + write!(f, "{s}") + }, } } } diff --git a/crates/polars-plan/src/plans/functions/rename.rs b/crates/polars-plan/src/plans/functions/rename.rs index fea6c2cc635c..7a58101e3731 100644 --- a/crates/polars-plan/src/plans/functions/rename.rs +++ b/crates/polars-plan/src/plans/functions/rename.rs @@ -2,8 +2,8 @@ use super::*; pub(super) fn rename_impl( mut df: DataFrame, - existing: &[SmartString], - new: &[SmartString], + existing: &[PlSmallStr], + new: &[PlSmallStr], ) -> PolarsResult { let positions = existing .iter() @@ -14,7 +14,7 @@ pub(super) fn rename_impl( // the column might be removed due to projection pushdown // so we only update if we can find it. if let Some(pos) = pos { - unsafe { df.get_columns_mut()[*pos].rename(name) }; + unsafe { df.get_columns_mut()[*pos].rename(name.clone()) }; } } // recreate dataframe so we check duplicates diff --git a/crates/polars-plan/src/plans/functions/schema.rs b/crates/polars-plan/src/plans/functions/schema.rs index 9d33f12e737b..79e2b8343230 100644 --- a/crates/polars-plan/src/plans/functions/schema.rs +++ b/crates/polars-plan/src/plans/functions/schema.rs @@ -1,17 +1,21 @@ +#[cfg(feature = "pivot")] use polars_core::utils::try_get_supertype; use super::*; +use crate::constants::get_len_name; -impl FunctionNode { +impl FunctionIR { pub(crate) fn clear_cached_schema(&self) { - use FunctionNode::*; + use FunctionIR::*; // We will likely add more branches later #[allow(clippy::single_match)] match self { - RowIndex { schema, .. } - | Explode { schema, .. } - | Rename { schema, .. } - | Unpivot { schema, .. } => { + #[cfg(feature = "pivot")] + Unpivot { schema, .. 
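The rewritten Display impl above collapses most arms into a fallback that relies on a `&FunctionIR -> &'static str` conversion (`let s: &str = v.into();`), presumably derived elsewhere (e.g. via strum's IntoStaticStr); that derivation is an assumption here. A hand-rolled sketch of the same idea, reusing the display names removed above:

```rust
use std::fmt;

// Hand-rolled stand-in for the `let s: &str = v.into()` fallback above; the
// variant names and strings mirror the removed Display arms.
enum Op {
    Rechunk,
    Rename,
    Explode,
    RowIndex,
}

impl From<&Op> for &'static str {
    fn from(op: &Op) -> Self {
        match op {
            Op::Rechunk => "RECHUNK",
            Op::Rename => "RENAME",
            Op::Explode => "EXPLODE",
            Op::RowIndex => "WITH ROW INDEX",
        }
    }
}

impl fmt::Display for Op {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let s: &str = self.into();
        write!(f, "{s}")
    }
}
```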
} => { + let mut guard = schema.lock().unwrap(); + *guard = None; + }, + RowIndex { schema, .. } | Explode { schema, .. } | Rename { schema, .. } => { let mut guard = schema.lock().unwrap(); *guard = None; }, @@ -23,7 +27,7 @@ impl FunctionNode { &self, input_schema: &'a SchemaRef, ) -> PolarsResult> { - use FunctionNode::*; + use FunctionIR::*; match self { Opaque { schema, .. } => match schema { None => Ok(Cow::Borrowed(input_schema)), @@ -33,19 +37,14 @@ impl FunctionNode { }, }, #[cfg(feature = "python")] - OpaquePython { schema, .. } => Ok(schema + OpaquePython(OpaquePythonUdf { schema, .. }) => Ok(schema .as_ref() .map(|schema| Cow::Owned(schema.clone())) .unwrap_or_else(|| Cow::Borrowed(input_schema))), Pipeline { schema, .. } => Ok(Cow::Owned(schema.clone())), - Count { alias, .. } => { + FastCount { alias, .. } => { let mut schema: Schema = Schema::with_capacity(1); - let name = SmartString::from( - alias - .as_ref() - .map(|alias| alias.as_ref()) - .unwrap_or(crate::constants::LEN), - ); + let name = alias.clone().unwrap_or_else(get_len_name); schema.insert_at_index(0, name, IDX_DTYPE)?; Ok(Cow::Owned(Arc::new(schema))) }, @@ -55,14 +54,12 @@ impl FunctionNode { { let mut new_schema = Schema::with_capacity(input_schema.len() * 2); for (name, dtype) in input_schema.iter() { - if _columns.iter().any(|item| item.as_ref() == name.as_str()) { + if _columns.iter().any(|item| item == name) { match dtype { DataType::Struct(flds) => { for fld in flds { - new_schema.with_column( - fld.name().clone(), - fld.data_type().clone(), - ); + new_schema + .with_column(fld.name().clone(), fld.dtype().clone()); } }, DataType::Unknown(_) => { @@ -94,10 +91,13 @@ impl FunctionNode { schema, .. } => rename_schema(input_schema, existing, new, schema), - RowIndex { schema, name, .. } => { - Ok(Cow::Owned(row_index_schema(schema, input_schema, name))) - }, + RowIndex { schema, name, .. } => Ok(Cow::Owned(row_index_schema( + schema, + input_schema, + name.clone(), + ))), Explode { schema, columns } => explode_schema(schema, input_schema, columns), + #[cfg(feature = "pivot")] Unpivot { schema, args } => unpivot_schema(args, schema, input_schema), } } @@ -106,14 +106,14 @@ impl FunctionNode { fn row_index_schema( cached_schema: &CachedSchema, input_schema: &SchemaRef, - name: &str, + name: PlSmallStr, ) -> SchemaRef { let mut guard = cached_schema.lock().unwrap(); if let Some(schema) = &*guard { return schema.clone(); } let mut schema = (**input_schema).clone(); - schema.insert_at_index(0, name.into(), IDX_DTYPE).unwrap(); + schema.insert_at_index(0, name, IDX_DTYPE).unwrap(); let schema_ref = Arc::new(schema); *guard = Some(schema_ref.clone()); schema_ref @@ -122,7 +122,7 @@ fn row_index_schema( fn explode_schema<'a>( cached_schema: &CachedSchema, schema: &'a Schema, - columns: &[Arc], + columns: &[PlSmallStr], ) -> PolarsResult> { let mut guard = cached_schema.lock().unwrap(); if let Some(schema) = &*guard { @@ -134,7 +134,7 @@ fn explode_schema<'a>( columns.iter().try_for_each(|name| { if let DataType::List(inner) = schema.try_get(name)? 
{ let inner = *inner.clone(); - schema.with_column(name.as_ref().into(), inner); + schema.with_column(name.clone(), inner); }; PolarsResult::Ok(()) })?; @@ -143,8 +143,9 @@ fn explode_schema<'a>( Ok(Cow::Owned(schema)) } +#[cfg(feature = "pivot")] fn unpivot_schema<'a>( - args: &UnpivotArgs, + args: &UnpivotArgsIR, cached_schema: &CachedSchema, input_schema: &'a Schema, ) -> PolarsResult> { @@ -156,7 +157,7 @@ fn unpivot_schema<'a>( let mut new_schema = args .index .iter() - .map(|id| Ok(Field::new(id, input_schema.try_get(id)?.clone()))) + .map(|id| Ok(Field::new(id.clone(), input_schema.try_get(id)?.clone()))) .collect::>()?; let variable_name = args .variable_name @@ -196,8 +197,8 @@ fn unpivot_schema<'a>( fn rename_schema<'a>( input_schema: &'a SchemaRef, - existing: &[SmartString], - new: &[SmartString], + existing: &[PlSmallStr], + new: &[PlSmallStr], cached_schema: &CachedSchema, ) -> PolarsResult> { let mut guard = cached_schema.lock().unwrap(); diff --git a/crates/polars-plan/src/plans/hive.rs b/crates/polars-plan/src/plans/hive.rs index a89c8a32a127..3fc7531ea2b3 100644 --- a/crates/polars-plan/src/plans/hive.rs +++ b/crates/polars-plan/src/plans/hive.rs @@ -17,7 +17,7 @@ pub struct HivePartitions { impl HivePartitions { pub fn get_projection_schema_and_indices( &self, - names: &PlHashSet, + names: &PlHashSet, ) -> (SchemaRef, Vec) { let mut out_schema = Schema::with_capacity(self.stats.schema().len()); let mut out_indices = Vec::with_capacity(self.stats.column_stats().len()); @@ -114,7 +114,7 @@ pub fn hive_partitions_from_paths( dtype.clone() }; - Ok(Field::new(name, dtype)) + Ok(Field::new(PlSmallStr::from_str(name), dtype)) }).collect::>()?) } else { let mut hive_schema = Schema::with_capacity(16); diff --git a/crates/polars-plan/src/plans/ir/dot.rs b/crates/polars-plan/src/plans/ir/dot.rs index 49e9bef1a3dc..51050f2fa877 100644 --- a/crates/polars-plan/src/plans/ir/dot.rs +++ b/crates/polars-plan/src/plans/ir/dot.rs @@ -2,6 +2,7 @@ use std::fmt; use std::path::PathBuf; use polars_core::schema::Schema; +use polars_utils::pl_str::PlSmallStr; use super::format::ExprIRSliceDisplay; use crate::constants::UNLIMITED_CACHE; @@ -32,9 +33,9 @@ impl fmt::Display for DotNode { #[inline(always)] fn write_label<'a, 'b>( - f: &'b mut fmt::Formatter<'a>, + f: &'a mut fmt::Formatter<'b>, id: DotNode, - mut w: impl FnMut(&mut EscapeLabel<'a, 'b>) -> fmt::Result, + mut w: impl FnMut(&mut EscapeLabel<'a>) -> fmt::Result, ) -> fmt::Result { write!(f, "{INDENT}{id}[label=\"")?; @@ -246,7 +247,7 @@ impl<'a> IRDotDisplay<'a> { })?; }, Scan { - paths, + sources, file_info, hive_parts: _, predicate, @@ -255,7 +256,7 @@ impl<'a> IRDotDisplay<'a> { output_schema: _, } => { let name: &str = scan_type.into(); - let path = PathsDisplay(paths.as_ref()); + let path = ScanSourcesDisplay(sources); let with_columns = options.with_columns.as_ref().map(|cols| cols.as_ref()); let with_columns = NumColumns(with_columns); let total_columns = @@ -341,11 +342,38 @@ impl<'a> IRDotDisplay<'a> { } // A few utility structures for formatting -pub(crate) struct PathsDisplay<'a>(pub &'a [PathBuf]); -struct NumColumns<'a>(Option<&'a [String]>); +pub struct PathsDisplay<'a>(pub &'a [PathBuf]); +pub struct ScanSourcesDisplay<'a>(pub &'a ScanSources); +struct NumColumns<'a>(Option<&'a [PlSmallStr]>); struct NumColumnsSchema<'a>(Option<&'a Schema>); struct OptionExprIRDisplay<'a>(Option>); +impl fmt::Display for ScanSourceRef<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + 
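`clear_cached_schema`, `row_index_schema` and `explode_schema` above all share the same compute-once pattern around `CachedSchema` (a mutex-guarded `Option`). The sketch below shows just that pattern, with a plain `String` standing in for `SchemaRef`.

```rust
use std::sync::{Arc, Mutex};

// Simplified stand-in for CachedSchema and the compute-once access pattern.
type Cached = Mutex<Option<Arc<String>>>;

fn get_or_compute(cache: &Cached, compute: impl FnOnce() -> String) -> Arc<String> {
    let mut guard = cache.lock().unwrap();
    if let Some(cached) = &*guard {
        return cached.clone();
    }
    let value = Arc::new(compute());
    *guard = Some(value.clone());
    value
}

fn clear(cache: &Cached) {
    // Mirrors clear_cached_schema: drop the cached value so it is rebuilt next time.
    *cache.lock().unwrap() = None;
}
```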
ScanSourceRef::Path(path) => path.display().fmt(f), + ScanSourceRef::File(_) => f.write_str("open-file"), + ScanSourceRef::Buffer(buff) => write!(f, "{} in-mem bytes", buff.len()), + } + } +} + +impl fmt::Display for ScanSourcesDisplay<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0.len() { + 0 => write!(f, "[]"), + 1 => write!(f, "[{}]", self.0.at(0)), + 2 => write!(f, "[{}, {}]", self.0.at(0), self.0.at(1)), + _ => write!( + f, + "[{}, ... {} other sources]", + self.0.at(0), + self.0.len() - 1, + ), + } + } +} + impl fmt::Display for PathsDisplay<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.0.len() { @@ -355,7 +383,7 @@ impl fmt::Display for PathsDisplay<'_> { _ => write!( f, "[{}, ... {} other files]", - self.0[0].to_string_lossy(), + self.0[0].display(), self.0.len() - 1, ), } @@ -390,9 +418,9 @@ impl fmt::Display for OptionExprIRDisplay<'_> { } /// Utility structure to write to a [`fmt::Formatter`] whilst escaping the output as a label name -struct EscapeLabel<'a, 'b>(&'b mut fmt::Formatter<'a>); +pub struct EscapeLabel<'a>(pub &'a mut dyn fmt::Write); -impl<'a, 'b> fmt::Write for EscapeLabel<'a, 'b> { +impl<'a> fmt::Write for EscapeLabel<'a> { fn write_str(&mut self, mut s: &str) -> fmt::Result { loop { let mut char_indices = s.char_indices(); diff --git a/crates/polars-plan/src/plans/ir/format.rs b/crates/polars-plan/src/plans/ir/format.rs index 5c685284e858..76de9f3beb24 100644 --- a/crates/polars-plan/src/plans/ir/format.rs +++ b/crates/polars-plan/src/plans/ir/format.rs @@ -1,14 +1,13 @@ use std::borrow::Cow; use std::fmt; use std::fmt::{Display, Formatter}; -use std::path::PathBuf; use polars_core::datatypes::AnyValue; use polars_core::schema::Schema; use polars_io::RowIndex; use recursive::recursive; -use super::ir::dot::PathsDisplay; +use self::ir::dot::ScanSourcesDisplay; use crate::prelude::*; pub struct IRDisplay<'a> { @@ -56,7 +55,7 @@ impl AsExpr for ExprIR { fn write_scan( f: &mut Formatter, name: &str, - path: &[PathBuf], + sources: &ScanSources, indent: usize, n_columns: i64, total_columns: usize, @@ -64,7 +63,12 @@ fn write_scan( slice: Option<(i64, usize)>, row_index: Option<&RowIndex>, ) -> fmt::Result { - write!(f, "{:indent$}{name} SCAN {}", "", PathsDisplay(path))?; + write!( + f, + "{:indent$}{name} SCAN {}", + "", + ScanSourcesDisplay(sources) + )?; let total_columns = total_columns - usize::from(row_index.is_some()); if n_columns > 0 { @@ -171,7 +175,7 @@ impl<'a> IRDisplay<'a> { write_scan( f, "PYTHON", - &[], + &ScanSources::default(), indent, n_columns, total_columns, @@ -221,7 +225,7 @@ impl<'a> IRDisplay<'a> { self.with_root(*input)._format(f, sub_indent) }, Scan { - paths, + sources, file_info, predicate, scan_type, @@ -239,7 +243,7 @@ impl<'a> IRDisplay<'a> { write_scan( f, scan_type.into(), - paths, + sources, indent, n_columns, file_info.schema.len(), @@ -482,7 +486,6 @@ impl<'a> Display for ExprIRDisplay<'a> { }, } }, - Nth(i) => write!(f, "nth({i})"), Len => write!(f, "len()"), Explode(expr) => { let expr = self.with_root(expr); @@ -588,14 +591,14 @@ impl<'a> Display for ExprIRDisplay<'a> { }, Cast { expr, - data_type, + dtype, options, } => { self.with_root(expr).fmt(f)?; if options.strict() { - write!(f, ".strict_cast({data_type:?})") + write!(f, ".strict_cast({dtype:?})") } else { - write!(f, ".cast({data_type:?})") + write!(f, ".cast({dtype:?})") } }, Ternary { @@ -639,7 +642,6 @@ impl<'a> Display for ExprIRDisplay<'a> { write!(f, "{input}.slice(offset={offset}, length={length})") 
}, - Wildcard => write!(f, "*"), }?; match self.output_name { diff --git a/crates/polars-plan/src/plans/ir/inputs.rs b/crates/polars-plan/src/plans/ir/inputs.rs index b00c91cddae4..2a7c14e300de 100644 --- a/crates/polars-plan/src/plans/ir/inputs.rs +++ b/crates/polars-plan/src/plans/ir/inputs.rs @@ -101,7 +101,7 @@ impl IR { options: *options, }, Scan { - paths, + sources, file_info, hive_parts, output_schema, @@ -114,7 +114,7 @@ impl IR { new_predicate = exprs.pop() } Scan { - paths: paths.clone(), + sources: sources.clone(), file_info: file_info.clone(), hive_parts: hive_parts.clone(), output_schema: output_schema.clone(), diff --git a/crates/polars-plan/src/plans/ir/mod.rs b/crates/polars-plan/src/plans/ir/mod.rs index 11d6b610ed02..a9eb45b6406f 100644 --- a/crates/polars-plan/src/plans/ir/mod.rs +++ b/crates/polars-plan/src/plans/ir/mod.rs @@ -1,22 +1,26 @@ mod dot; mod format; mod inputs; +mod scan_sources; mod schema; pub(crate) mod tree_format; use std::borrow::Cow; use std::fmt; -use std::path::PathBuf; -pub use dot::IRDotDisplay; +pub use dot::{EscapeLabel, IRDotDisplay, PathsDisplay, ScanSourcesDisplay}; pub use format::{ExprIRDisplay, IRDisplay}; use hive::HivePartitions; use polars_core::prelude::*; use polars_utils::idx_vec::UnitVec; use polars_utils::unitvec; +pub use scan_sources::{ScanSourceIter, ScanSourceRef, ScanSources}; +#[cfg(feature = "ir_serde")] +use serde::{Deserialize, Serialize}; use crate::prelude::*; +#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] pub struct IRPlan { pub lp_top: Node, pub lp_arena: Arena, @@ -33,6 +37,7 @@ pub struct IRPlanRef<'a> { /// [`IR`] is a representation of [`DslPlan`] with [`Node`]s which are allocated in an [`Arena`] /// In this IR the logical plan has access to the full dataset. #[derive(Clone, Debug, Default)] +#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] pub enum IR { #[cfg(feature = "python")] PythonScan { @@ -48,7 +53,7 @@ pub enum IR { predicate: ExprIR, }, Scan { - paths: Arc>, + sources: ScanSources, file_info: FileInfo, hive_parts: Option>>, predicate: Option, @@ -105,6 +110,7 @@ pub enum IR { keys: Vec, aggs: Vec, schema: SchemaRef, + #[cfg_attr(feature = "ir_serde", serde(skip))] apply: Option>, maintain_order: bool, options: Arc, @@ -125,11 +131,11 @@ pub enum IR { }, Distinct { input: Node, - options: DistinctOptions, + options: DistinctOptionsIR, }, MapFunction { input: Node, - function: FunctionNode, + function: FunctionIR, }, Union { inputs: Vec, @@ -220,7 +226,7 @@ impl<'a> IRPlanRef<'a> { return None; }; - let FunctionNode::Pipeline { original, .. } = function else { + let FunctionIR::Pipeline { original, .. } = function else { return None; }; diff --git a/crates/polars-plan/src/plans/ir/scan_sources.rs b/crates/polars-plan/src/plans/ir/scan_sources.rs new file mode 100644 index 000000000000..1bdb92fda904 --- /dev/null +++ b/crates/polars-plan/src/plans/ir/scan_sources.rs @@ -0,0 +1,262 @@ +use std::fs::File; +use std::path::{Path, PathBuf}; +use std::sync::Arc; + +use polars_core::error::{feature_gated, PolarsResult}; +use polars_utils::mmap::MemSlice; +use polars_utils::pl_str::PlSmallStr; + +use super::DslScanSources; + +/// Set of sources to scan from +/// +/// This is can either be a list of paths to files, opened files or in-memory buffers. Mixing of +/// buffers is not currently possible. 
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone)] +pub enum ScanSources { + Paths(Arc<[PathBuf]>), + + #[cfg_attr(feature = "serde", serde(skip))] + Files(Arc<[File]>), + #[cfg_attr(feature = "serde", serde(skip))] + Buffers(Arc<[bytes::Bytes]>), +} + +/// A reference to a single item in [`ScanSources`] +#[derive(Debug, Clone, Copy)] +pub enum ScanSourceRef<'a> { + Path(&'a Path), + File(&'a File), + Buffer(&'a bytes::Bytes), +} + +/// An iterator for [`ScanSources`] +pub struct ScanSourceIter<'a> { + sources: &'a ScanSources, + offset: usize, +} + +impl Default for ScanSources { + fn default() -> Self { + Self::Buffers(Arc::default()) + } +} + +impl std::hash::Hash for ScanSources { + fn hash(&self, state: &mut H) { + std::mem::discriminant(self).hash(state); + + // @NOTE: This is a bit crazy + // + // We don't really want to hash the file descriptors or the whole buffers so for now we + // just settle with the fact that the memory behind Arc's does not really move. Therefore, + // we can just hash the pointer. + match self { + Self::Paths(paths) => paths.hash(state), + Self::Files(files) => files.as_ptr().hash(state), + Self::Buffers(buffers) => buffers.as_ptr().hash(state), + } + } +} + +impl PartialEq for ScanSources { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (ScanSources::Paths(l), ScanSources::Paths(r)) => l == r, + (ScanSources::Files(l), ScanSources::Files(r)) => std::ptr::eq(l.as_ptr(), r.as_ptr()), + (ScanSources::Buffers(l), ScanSources::Buffers(r)) => { + std::ptr::eq(l.as_ptr(), r.as_ptr()) + }, + _ => false, + } + } +} + +impl Eq for ScanSources {} + +impl ScanSources { + pub fn iter(&self) -> ScanSourceIter { + ScanSourceIter { + sources: self, + offset: 0, + } + } + + pub fn to_dsl(self, is_expanded: bool) -> DslScanSources { + DslScanSources { + sources: self, + is_expanded, + } + } + + /// Are the sources all paths? + pub fn is_paths(&self) -> bool { + matches!(self, Self::Paths(_)) + } + + /// Try cast the scan sources to [`ScanSources::Paths`] + pub fn as_paths(&self) -> Option<&[PathBuf]> { + match self { + Self::Paths(paths) => Some(paths.as_ref()), + Self::Files(_) | Self::Buffers(_) => None, + } + } + + /// Try cast the scan sources to [`ScanSources::Paths`] with a clone + pub fn into_paths(&self) -> Option> { + match self { + Self::Paths(paths) => Some(paths.clone()), + Self::Files(_) | Self::Buffers(_) => None, + } + } + + /// Try get the first path in the scan sources + pub fn first_path(&self) -> Option<&Path> { + match self { + Self::Paths(paths) => paths.first().map(|p| p.as_path()), + Self::Files(_) | Self::Buffers(_) => None, + } + } + + /// Is the first path a cloud URL? 
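As the `@NOTE` above explains, hashing open files or whole buffers would be expensive, so `ScanSources` falls back to pointer identity for those variants (the memory behind an `Arc` does not move). A reduced sketch of that Hash/PartialEq strategy, with simplified variants:

```rust
use std::hash::{Hash, Hasher};
use std::path::PathBuf;
use std::sync::Arc;

// Reduced stand-in for ScanSources: paths hash by value, buffer-like variants
// hash and compare by the slice's data pointer.
#[derive(Debug, Clone)]
enum Sources {
    Paths(Arc<[PathBuf]>),
    Buffers(Arc<[Vec<u8>]>),
}

impl Hash for Sources {
    fn hash<H: Hasher>(&self, state: &mut H) {
        std::mem::discriminant(self).hash(state);
        match self {
            Sources::Paths(paths) => paths.hash(state),
            Sources::Buffers(buffers) => buffers.as_ptr().hash(state),
        }
    }
}

impl PartialEq for Sources {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Sources::Paths(l), Sources::Paths(r)) => l == r,
            (Sources::Buffers(l), Sources::Buffers(r)) => std::ptr::eq(l.as_ptr(), r.as_ptr()),
            _ => false,
        }
    }
}

impl Eq for Sources {}
```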
+ pub fn is_cloud_url(&self) -> bool { + self.first_path().is_some_and(polars_io::is_cloud_url) + } + + pub fn len(&self) -> usize { + match self { + Self::Paths(s) => s.len(), + Self::Files(s) => s.len(), + Self::Buffers(s) => s.len(), + } + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn first(&self) -> Option { + self.get(0) + } + + /// Turn the [`ScanSources`] into some kind of identifier + pub fn id(&self) -> PlSmallStr { + if self.is_empty() { + return PlSmallStr::from_static("EMPTY"); + } + + match self { + Self::Paths(paths) => { + PlSmallStr::from_str(paths.first().unwrap().to_string_lossy().as_ref()) + }, + Self::Files(_) => PlSmallStr::from_static("OPEN_FILES"), + Self::Buffers(_) => PlSmallStr::from_static("IN_MEMORY"), + } + } + + /// Get the scan source at specific address + pub fn get(&self, idx: usize) -> Option { + match self { + Self::Paths(paths) => paths.get(idx).map(|p| ScanSourceRef::Path(p)), + Self::Files(files) => files.get(idx).map(ScanSourceRef::File), + Self::Buffers(buffers) => buffers.get(idx).map(ScanSourceRef::Buffer), + } + } + + /// Get the scan source at specific address + /// + /// # Panics + /// + /// If the `idx` is out of range. + #[track_caller] + pub fn at(&self, idx: usize) -> ScanSourceRef { + self.get(idx).unwrap() + } +} + +impl<'a> ScanSourceRef<'a> { + /// Get the name for `include_paths` + pub fn to_include_path_name(&self) -> &str { + match self { + Self::Path(path) => path.to_str().unwrap(), + Self::File(_) => "open-file", + Self::Buffer(_) => "in-mem", + } + } + + /// Turn the scan source into a memory slice + pub fn to_memslice(&self) -> PolarsResult { + self.to_memslice_possibly_async(false, None, 0) + } + + pub fn to_memslice_async_latest(&self, run_async: bool) -> PolarsResult { + match self { + ScanSourceRef::Path(path) => { + let file = if run_async { + feature_gated!("cloud", { + polars_io::file_cache::FILE_CACHE + .get_entry(path.to_str().unwrap()) + // Safety: This was initialized by schema inference. + .unwrap() + .try_open_assume_latest()? + }) + } else { + polars_utils::open_file(path)? + }; + + MemSlice::from_file(&file) + }, + ScanSourceRef::File(file) => MemSlice::from_file(file), + ScanSourceRef::Buffer(buff) => Ok(MemSlice::from_bytes((*buff).clone())), + } + } + + pub fn to_memslice_possibly_async( + &self, + run_async: bool, + #[cfg(feature = "cloud")] cache_entries: Option< + &Vec>, + >, + #[cfg(not(feature = "cloud"))] cache_entries: Option<&()>, + index: usize, + ) -> PolarsResult { + match self { + Self::Path(path) => { + let file = if run_async { + feature_gated!("cloud", { + cache_entries.unwrap()[index].try_open_check_latest()? + }) + } else { + polars_utils::open_file(path)? 
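`to_memslice_possibly_async` above uses cfg attributes on its parameters, so the cache-entry argument only carries a real type when the `cloud` feature is enabled and degenerates to `Option<&()>` otherwise. A minimal sketch of that signature trick; the function name and parameter types here are illustrative, not the polars ones.

```rust
// cfg-switched parameter types: callers keep a single call site, but the
// cache-entry argument is only meaningful with the "cloud" feature.
fn read_source(
    run_async: bool,
    #[cfg(feature = "cloud")] cache_entries: Option<&Vec<String>>,
    #[cfg(not(feature = "cloud"))] cache_entries: Option<&()>,
    index: usize,
) -> Vec<u8> {
    let _ = (run_async, cache_entries, index);
    Vec::new()
}
```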
+ }; + + MemSlice::from_file(&file) + }, + Self::File(file) => MemSlice::from_file(file), + Self::Buffer(buff) => Ok(MemSlice::from_bytes((*buff).clone())), + } + } +} + +impl<'a> Iterator for ScanSourceIter<'a> { + type Item = ScanSourceRef<'a>; + + fn next(&mut self) -> Option { + let item = match self.sources { + ScanSources::Paths(paths) => ScanSourceRef::Path(paths.get(self.offset)?), + ScanSources::Files(files) => ScanSourceRef::File(files.get(self.offset)?), + ScanSources::Buffers(buffers) => ScanSourceRef::Buffer(buffers.get(self.offset)?), + }; + + self.offset += 1; + Some(item) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.sources.len() - self.offset; + (len, Some(len)) + } +} + +impl<'a> ExactSizeIterator for ScanSourceIter<'a> {} diff --git a/crates/polars-plan/src/plans/ir/schema.rs b/crates/polars-plan/src/plans/ir/schema.rs index 5b5042e50377..1586463a8c0f 100644 --- a/crates/polars-plan/src/plans/ir/schema.rs +++ b/crates/polars-plan/src/plans/ir/schema.rs @@ -107,4 +107,60 @@ impl IR { }; Cow::Borrowed(schema) } + + /// Get the schema of the logical plan node, using caching. + #[recursive] + pub fn schema_with_cache<'a>( + node: Node, + arena: &'a Arena, + cache: &mut PlHashMap>, + ) -> Arc { + use IR::*; + if let Some(schema) = cache.get(&node) { + return schema.clone(); + } + + let schema = match arena.get(node) { + #[cfg(feature = "python")] + PythonScan { options } => options + .output_schema + .as_ref() + .unwrap_or(&options.schema) + .clone(), + Union { inputs, .. } => IR::schema_with_cache(inputs[0], arena, cache), + HConcat { schema, .. } => schema.clone(), + Cache { input, .. } + | Sort { input, .. } + | Filter { input, .. } + | Distinct { input, .. } + | Sink { input, .. } + | Slice { input, .. } => IR::schema_with_cache(*input, arena, cache), + Scan { + output_schema, + file_info, + .. + } => output_schema.as_ref().unwrap_or(&file_info.schema).clone(), + DataFrameScan { + schema, + output_schema, + .. + } => output_schema.as_ref().unwrap_or(schema).clone(), + Select { schema, .. } + | Reduce { schema, .. } + | GroupBy { schema, .. } + | Join { schema, .. } + | HStack { schema, .. } + | ExtContext { schema, .. } + | SimpleProjection { + columns: schema, .. + } => schema.clone(), + MapFunction { input, function } => { + let input_schema = IR::schema_with_cache(*input, arena, cache); + function.schema(&input_schema).unwrap().into_owned() + }, + Invalid => unreachable!(), + }; + cache.insert(node, schema.clone()); + schema + } } diff --git a/crates/polars-plan/src/plans/ir/tree_format.rs b/crates/polars-plan/src/plans/ir/tree_format.rs index 72336bf4e2e2..3ecb9507fb9f 100644 --- a/crates/polars-plan/src/plans/ir/tree_format.rs +++ b/crates/polars-plan/src/plans/ir/tree_format.rs @@ -26,17 +26,15 @@ impl fmt::Display for TreeFmtAExpr<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let s = match self.0 { AExpr::Explode(_) => "explode", - AExpr::Alias(_, name) => return write!(f, "alias({})", name.as_ref()), - AExpr::Column(name) => return write!(f, "col({})", name.as_ref()), + AExpr::Alias(_, name) => return write!(f, "alias({})", name), + AExpr::Column(name) => return write!(f, "col({})", name), AExpr::Literal(lv) => return write!(f, "lit({lv:?})"), AExpr::BinaryExpr { op, .. } => return write!(f, "binary: {}", op), - AExpr::Cast { - data_type, options, .. - } => { + AExpr::Cast { dtype, options, .. 
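The new `IR::schema_with_cache` above walks the arena recursively and memoises each node's schema in a hash map, so shared subplans are resolved once. A stripped-down sketch of the same recursion, with a `Vec` as the arena and a `String` as the schema:

```rust
use std::collections::HashMap;
use std::sync::Arc;

// Stand-in node type: either a leaf with its own "schema" or a node that
// simply forwards the schema of its input.
enum IrNode {
    Leaf(String),
    Passthrough(usize), // index of the input node in the arena
}

fn schema_with_cache(
    node: usize,
    arena: &[IrNode],
    cache: &mut HashMap<usize, Arc<String>>,
) -> Arc<String> {
    if let Some(schema) = cache.get(&node) {
        return schema.clone();
    }
    let schema = match &arena[node] {
        IrNode::Leaf(s) => Arc::new(s.clone()),
        IrNode::Passthrough(input) => schema_with_cache(*input, arena, cache),
    };
    cache.insert(node, schema.clone());
    schema
}
```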
} => { return if options.strict() { - write!(f, "strict cast({})", data_type) + write!(f, "strict cast({})", dtype) } else { - write!(f, "cast({})", data_type) + write!(f, "cast({})", dtype) } }, AExpr::Sort { options, .. } => { @@ -69,10 +67,8 @@ impl fmt::Display for TreeFmtAExpr<'_> { }, AExpr::Function { function, .. } => return write!(f, "function: {function}"), AExpr::Window { .. } => "window", - AExpr::Wildcard => "*", AExpr::Slice { .. } => "slice", AExpr::Len => constants::LEN, - AExpr::Nth(v) => return write!(f, "nth({})", v), }; write!(f, "{s}") diff --git a/crates/polars-plan/src/plans/lit.rs b/crates/polars-plan/src/plans/lit.rs index c0dcab76d3c6..48f2e8aa7e45 100644 --- a/crates/polars-plan/src/plans/lit.rs +++ b/crates/polars-plan/src/plans/lit.rs @@ -8,7 +8,7 @@ use polars_utils::hashing::hash_to_partition; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -use crate::constants::{get_literal_name, LITERAL_NAME}; +use crate::constants::get_literal_name; use crate::prelude::*; #[derive(Clone, PartialEq)] @@ -18,7 +18,7 @@ pub enum LiteralValue { /// A binary true or false. Boolean(bool), /// A UTF8 encoded string type. - String(String), + String(PlSmallStr), /// A raw binary array Binary(Vec), /// An unsigned 8-bit integer number. @@ -51,7 +51,7 @@ pub enum LiteralValue { Range { low: i64, high: i64, - data_type: DataType, + dtype: DataType, }, #[cfg(feature = "dtype-date")] Date(i32), @@ -67,22 +67,22 @@ pub enum LiteralValue { // Used for dynamic languages Int(i128), // Dynamic string, still needs to be made concrete. - StrCat(String), + StrCat(PlSmallStr), } impl LiteralValue { /// Get the output name as `&str`. - pub(crate) fn output_name(&self) -> &str { + pub(crate) fn output_name(&self) -> &PlSmallStr { match self { LiteralValue::Series(s) => s.name(), - _ => LITERAL_NAME, + _ => get_literal_name(), } } - /// Get the output name as [`ColumnName`]. - pub(crate) fn output_column_name(&self) -> ColumnName { + /// Get the output name as [`PlSmallStr`]. + pub(crate) fn output_column_name(&self) -> &PlSmallStr { match self { - LiteralValue::Series(s) => ColumnName::from(s.name()), + LiteralValue::Series(s) => s.name(), _ => get_literal_name(), } } @@ -139,12 +139,8 @@ impl LiteralValue { Int(v) => materialize_dyn_int(*v), Float(v) => AnyValue::Float64(*v), StrCat(v) => AnyValue::String(v), - Range { - low, - high, - data_type, - } => { - let opt_s = match data_type { + Range { low, high, dtype } => { + let opt_s = match dtype { DataType::Int32 => { if *low < i32::MIN as i64 || *high > i32::MAX as i64 { return None; @@ -152,12 +148,14 @@ impl LiteralValue { let low = *low as i32; let high = *high as i32; - new_int_range::(low, high, 1, "range").ok() + new_int_range::(low, high, 1, PlSmallStr::from_static("range")) + .ok() }, DataType::Int64 => { let low = *low; let high = *high; - new_int_range::(low, high, 1, "range").ok() + new_int_range::(low, high, 1, PlSmallStr::from_static("range")) + .ok() }, DataType::UInt32 => { if *low < 0 || *high > u32::MAX as i64 { @@ -165,7 +163,8 @@ impl LiteralValue { } let low = *low as u32; let high = *high as u32; - new_int_range::(low, high, 1, "range").ok() + new_int_range::(low, high, 1, PlSmallStr::from_static("range")) + .ok() }, _ => return None, }; @@ -201,7 +200,7 @@ impl LiteralValue { LiteralValue::Decimal(_, scale) => DataType::Decimal(None, Some(*scale)), LiteralValue::String(_) => DataType::String, LiteralValue::Binary(_) => DataType::Binary, - LiteralValue::Range { data_type, .. 
} => data_type.clone(), + LiteralValue::Range { dtype, .. } => dtype.clone(), #[cfg(feature = "dtype-date")] LiteralValue::Date(_) => DataType::Date, #[cfg(feature = "dtype-datetime")] @@ -217,6 +216,17 @@ impl LiteralValue { LiteralValue::StrCat(_) => DataType::Unknown(UnknownKind::Str), } } + + pub(crate) fn new_idxsize(value: IdxSize) -> Self { + #[cfg(feature = "bigidx")] + { + LiteralValue::UInt64(value) + } + #[cfg(not(feature = "bigidx"))] + { + LiteralValue::UInt32(value) + } + } } pub trait Literal { @@ -237,15 +247,21 @@ pub trait TypedLiteral: Literal { impl TypedLiteral for String {} impl TypedLiteral for &str {} -impl Literal for String { +impl Literal for PlSmallStr { fn lit(self) -> Expr { Expr::Literal(LiteralValue::String(self)) } } +impl Literal for String { + fn lit(self) -> Expr { + Expr::Literal(LiteralValue::String(PlSmallStr::from_string(self))) + } +} + impl<'a> Literal for &'a str { fn lit(self) -> Expr { - Expr::Literal(LiteralValue::String(self.to_string())) + Expr::Literal(LiteralValue::String(PlSmallStr::from_str(self))) } } @@ -267,7 +283,7 @@ impl TryFrom> for LiteralValue { match value { AnyValue::Null => Ok(Self::Null), AnyValue::Boolean(b) => Ok(Self::Boolean(b)), - AnyValue::String(s) => Ok(Self::String(s.to_string())), + AnyValue::String(s) => Ok(Self::String(PlSmallStr::from_str(s))), AnyValue::Binary(b) => Ok(Self::Binary(b.to_vec())), #[cfg(feature = "dtype-u8")] AnyValue::UInt8(u) => Ok(Self::UInt8(u)), @@ -294,16 +310,16 @@ impl TryFrom> for LiteralValue { #[cfg(feature = "dtype-time")] AnyValue::Time(v) => Ok(LiteralValue::Time(v)), AnyValue::List(l) => Ok(Self::Series(SpecialEq::new(l))), - AnyValue::StringOwned(o) => Ok(Self::String(o.into())), + AnyValue::StringOwned(o) => Ok(Self::String(o)), #[cfg(feature = "dtype-categorical")] AnyValue::Categorical(c, rev_mapping, arr) | AnyValue::Enum(c, rev_mapping, arr) => { if arr.is_null() { - Ok(Self::String(rev_mapping.get(c).to_string())) + Ok(Self::String(PlSmallStr::from_str(rev_mapping.get(c)))) } else { unsafe { - Ok(Self::String( - arr.deref_unchecked().value(c as usize).to_string(), - )) + Ok(Self::String(PlSmallStr::from_str( + arr.deref_unchecked().value(c as usize), + ))) } } }, @@ -482,14 +498,10 @@ impl Hash for LiteralValue { rng = rng.rotate_right(17).wrapping_add(RANDOM); } }, - LiteralValue::Range { - low, - high, - data_type, - } => { + LiteralValue::Range { low, high, dtype } => { low.hash(state); high.hash(state); - data_type.hash(state) + dtype.hash(state) }, _ => { if let Some(v) = self.to_any_value() { diff --git a/crates/polars-plan/src/plans/mod.rs b/crates/polars-plan/src/plans/mod.rs index bb25c627170b..92eeb783bf76 100644 --- a/crates/polars-plan/src/plans/mod.rs +++ b/crates/polars-plan/src/plans/mod.rs @@ -1,6 +1,5 @@ use std::fmt; use std::fmt::Debug; -use std::path::PathBuf; use std::sync::{Arc, Mutex, RwLock}; use hive::HivePartitions; @@ -51,8 +50,6 @@ pub use schema::*; use serde::{Deserialize, Serialize}; use strum_macros::IntoStaticStr; -pub type ColumnName = Arc; - #[derive(Clone, Copy, Debug)] pub enum Context { /// Any operation that is done on groups @@ -61,6 +58,13 @@ pub enum Context { Default, } +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Clone)] +pub struct DslScanSources { + pub sources: ScanSources, + pub is_expanded: bool, +} + // https://stackoverflow.com/questions/1031076/what-are-projection-and-selection #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum DslPlan { @@ -78,7 +82,7 @@ pub enum DslPlan { 
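The `Literal` impls above convert plain Rust values into literal expressions; the change is that owned strings now become `PlSmallStr`-backed literals instead of `String` ones. A tiny self-contained sketch of the conversion-trait pattern itself, with a simplified `Expr` standing in for the polars type:

```rust
// Simplified stand-ins for Expr/LiteralValue to illustrate the Literal trait.
#[derive(Debug, PartialEq)]
enum Expr {
    StringLit(String),
    IntLit(i64),
}

trait Literal {
    fn lit(self) -> Expr;
}

impl Literal for String {
    fn lit(self) -> Expr {
        Expr::StringLit(self)
    }
}

impl Literal for &str {
    fn lit(self) -> Expr {
        Expr::StringLit(self.to_owned())
    }
}

impl Literal for i64 {
    fn lit(self) -> Expr {
        Expr::IntLit(self)
    }
}

fn lit(value: impl Literal) -> Expr {
    value.lit()
}

fn main() {
    assert_eq!(lit("a"), lit(String::from("a")));
}
```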
cache_hits: u32, }, Scan { - paths: Arc>, bool)>>, + sources: Arc>, // Option as this is mostly materialized on the IR phase. // During conversion we update the value in the DSL as well // This is to cater to use cases where parts of a `LazyFrame` @@ -119,8 +123,11 @@ pub enum DslPlan { Join { input_left: Arc, input_right: Arc, + // Invariant: left_on and right_on are equal length. left_on: Vec, right_on: Vec, + // Invariant: Either left_on/right_on or predicates is set (non-empty). + predicates: Vec, options: Arc, }, /// Adding columns to the table without a Join @@ -132,7 +139,7 @@ pub enum DslPlan { /// Remove duplicates from the table Distinct { input: Arc, - options: DistinctOptions, + options: DistinctOptionsDSL, }, /// Sort the table Sort { @@ -192,11 +199,11 @@ impl Clone for DslPlan { Self::PythonScan { options } => Self::PythonScan { options: options.clone() }, Self::Filter { input, predicate } => Self::Filter { input: input.clone(), predicate: predicate.clone() }, Self::Cache { input, id, cache_hits } => Self::Cache { input: input.clone(), id: id.clone(), cache_hits: cache_hits.clone() }, - Self::Scan { paths, file_info, hive_parts, predicate, file_options, scan_type } => Self::Scan { paths: paths.clone(), file_info: file_info.clone(), hive_parts: hive_parts.clone(), predicate: predicate.clone(), file_options: file_options.clone(), scan_type: scan_type.clone() }, + Self::Scan { sources, file_info, hive_parts, predicate, file_options, scan_type } => Self::Scan { sources: sources.clone(), file_info: file_info.clone(), hive_parts: hive_parts.clone(), predicate: predicate.clone(), file_options: file_options.clone(), scan_type: scan_type.clone() }, Self::DataFrameScan { df, schema, output_schema, filter: selection } => Self::DataFrameScan { df: df.clone(), schema: schema.clone(), output_schema: output_schema.clone(), filter: selection.clone() }, Self::Select { expr, input, options } => Self::Select { expr: expr.clone(), input: input.clone(), options: options.clone() }, Self::GroupBy { input, keys, aggs, apply, maintain_order, options } => Self::GroupBy { input: input.clone(), keys: keys.clone(), aggs: aggs.clone(), apply: apply.clone(), maintain_order: maintain_order.clone(), options: options.clone() }, - Self::Join { input_left, input_right, left_on, right_on, options } => Self::Join { input_left: input_left.clone(), input_right: input_right.clone(), left_on: left_on.clone(), right_on: right_on.clone(), options: options.clone() }, + Self::Join { input_left, input_right, left_on, right_on, predicates, options } => Self::Join { input_left: input_left.clone(), input_right: input_right.clone(), left_on: left_on.clone(), right_on: right_on.clone(), options: options.clone(), predicates: predicates.clone() }, Self::HStack { input, exprs, options } => Self::HStack { input: input.clone(), exprs: exprs.clone(), options: options.clone() }, Self::Distinct { input, options } => Self::Distinct { input: input.clone(), options: options.clone() }, Self::Sort {input,by_column, slice, sort_options } => Self::Sort { input: input.clone(), by_column: by_column.clone(), slice: slice.clone(), sort_options: sort_options.clone() }, @@ -213,7 +220,7 @@ impl Clone for DslPlan { impl Default for DslPlan { fn default() -> Self { - let df = DataFrame::new::(vec![]).unwrap(); + let df = DataFrame::empty(); let schema = df.schema(); DslPlan::DataFrameScan { df: Arc::new(df), @@ -247,7 +254,12 @@ impl DslPlan { let mut lp_arena = Arena::with_capacity(16); let mut expr_arena = Arena::with_capacity(16); - let node = 
to_alp(self, &mut expr_arena, &mut lp_arena, true, true)?; + let node = to_alp( + self, + &mut expr_arena, + &mut lp_arena, + &mut OptFlags::default(), + )?; let plan = IRPlan::new(node, lp_arena, expr_arena); Ok(plan) diff --git a/crates/polars-plan/src/plans/optimizer/cache_states.rs b/crates/polars-plan/src/plans/optimizer/cache_states.rs index b66f73a18ae8..da13d047d43f 100644 --- a/crates/polars-plan/src/plans/optimizer/cache_states.rs +++ b/crates/polars-plan/src/plans/optimizer/cache_states.rs @@ -6,7 +6,7 @@ fn get_upper_projections( parent: Node, lp_arena: &Arena, expr_arena: &Arena, - names_scratch: &mut Vec, + names_scratch: &mut Vec, found_required_columns: &mut bool, ) -> bool { let parent = lp_arena.get(parent); @@ -15,7 +15,7 @@ fn get_upper_projections( // During projection pushdown all accumulated. match parent { SimpleProjection { columns, .. } => { - let iter = columns.iter_names().map(|s| ColumnName::from(s.as_str())); + let iter = columns.iter_names_cloned(); names_scratch.extend(iter); *found_required_columns = true; false @@ -138,7 +138,7 @@ pub(super) fn set_cache_states( parents: Vec, cache_nodes: Vec, // Union over projected names. - names_union: PlHashSet, + names_union: PlHashSet, // Union over predicates. predicate_union: PlHashMap, } @@ -264,11 +264,7 @@ pub(super) fn set_cache_states( // all columns if !found_required_columns { let schema = lp.schema(lp_arena); - v.names_union.extend( - schema - .iter_names() - .map(|name| ColumnName::from(name.as_str())), - ); + v.names_union.extend(schema.iter_names_cloned()); } } frame.cache_id = Some(*id); @@ -343,15 +339,15 @@ pub(super) fn set_cache_states( // order we discovered all values. let child_schema = child_lp.schema(lp_arena); let child_schema = child_schema.as_ref(); - let projection: Vec<_> = child_schema + let projection = child_schema .iter_names() - .flat_map(|name| columns.get(name.as_str()).map(|name| name.as_ref())) - .collect(); + .flat_map(|name| columns.get(name.as_str()).cloned()) + .collect::>(); let new_child = lp_arena.add(child_lp); let lp = IRBuilder::new(new_child, expr_arena, lp_arena) - .project_simple(projection.iter().copied()) + .project_simple(projection) .unwrap() .build(); diff --git a/crates/polars-plan/src/plans/optimizer/cluster_with_columns.rs b/crates/polars-plan/src/plans/optimizer/cluster_with_columns.rs index 160a9cbdbc0f..b3f52c6e30a9 100644 --- a/crates/polars-plan/src/plans/optimizer/cluster_with_columns.rs +++ b/crates/polars-plan/src/plans/optimizer/cluster_with_columns.rs @@ -8,9 +8,9 @@ use polars_utils::vec::inplace_zip_filtermap; use super::aexpr::AExpr; use super::ir::IR; -use super::{aexpr_to_leaf_names_iter, ColumnName}; +use super::{aexpr_to_leaf_names_iter, PlSmallStr}; -type ColumnMap = PlHashMap; +type ColumnMap = PlHashMap; fn column_map_finalize_bitset(bitset: &mut MutableBitmap, column_map: &ColumnMap) { assert!(bitset.len() <= column_map.len()); @@ -19,7 +19,7 @@ fn column_map_finalize_bitset(bitset: &mut MutableBitmap, column_map: &ColumnMap bitset.extend_constant(column_map.len() - size, false); } -fn column_map_set(bitset: &mut MutableBitmap, column_map: &mut ColumnMap, column: ColumnName) { +fn column_map_set(bitset: &mut MutableBitmap, column_map: &mut ColumnMap, column: PlSmallStr) { let size = column_map.len(); column_map .entry(column) @@ -92,7 +92,7 @@ pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &Arena) column_map_set( &mut input_genset, column_map, - input_expr.output_name_arc().clone(), + input_expr.output_name().clone(), ); } 
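`cluster_with_columns` above keys its liveness bitsets by a per-name index stored in `column_map` (now a `PlHashMap<PlSmallStr, usize>`). The sketch below shows that bookkeeping under the assumption that `column_map_set` simply assigns the next free index the first time a name is seen; a `Vec<bool>` stands in for `MutableBitmap`.

```rust
use std::collections::HashMap;

// Assign a stable index to each column name on first sight and mark it live.
fn column_map_set(
    bitset: &mut Vec<bool>,
    column_map: &mut HashMap<String, usize>,
    column: String,
) {
    let size = column_map.len();
    let idx = *column_map.entry(column).or_insert(size);
    if bitset.len() <= idx {
        bitset.resize(idx + 1, false);
    }
    bitset[idx] = true;
}

fn main() {
    let mut column_map = HashMap::new();
    let mut live = Vec::new();
    column_map_set(&mut live, &mut column_map, "a".to_string());
    column_map_set(&mut live, &mut column_map, "b".to_string());
    column_map_set(&mut live, &mut column_map, "a".to_string());
    assert_eq!(column_map.len(), 2);
    assert_eq!(live, vec![true, true]);
}
```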
@@ -132,14 +132,12 @@ pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &Arena) return Some((expr, liveset)); } - let column_name = expr.output_name_arc(); + let column_name = expr.output_name(); let is_pushable = if let Some(idx) = column_map.get(column_name) { let does_input_alias_also_expr = input_genset.get(*idx); let is_alias_live_in_current = current_liveset.get(*idx); if does_input_alias_also_expr && !is_alias_live_in_current { - let column_name = column_name.as_ref(); - // @NOTE: Pruning of re-assigned columns // // We checked if this expression output is also assigned by the input and @@ -190,7 +188,7 @@ pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &Arena) // This will pushdown the expressions that "has an output column that is mentioned by // neighbour columns, but all those neighbours were being pushed down". for candidate in potential_pushable.iter().copied() { - let column_name = current_exprs[candidate].output_name_arc(); + let column_name = current_exprs[candidate].output_name(); let column_idx = column_map.get(column_name).unwrap(); current_liveset.clear(); @@ -258,7 +256,7 @@ pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &Arena) if do_pushdown { needs_simple_projection = has_seen_unpushable; - let column = expr.output_name_arc().as_ref(); + let column = expr.output_name().as_ref(); // @NOTE: we cannot just use the index here, as there might be renames that sit // earlier in the schema let datatype = current_schema.get(column).unwrap(); diff --git a/crates/polars-plan/src/plans/optimizer/collapse_and_project.rs b/crates/polars-plan/src/plans/optimizer/collapse_and_project.rs index e4c0ac87151a..266ac6cd3335 100644 --- a/crates/polars-plan/src/plans/optimizer/collapse_and_project.rs +++ b/crates/polars-plan/src/plans/optimizer/collapse_and_project.rs @@ -52,10 +52,10 @@ impl OptimizationRule for SimpleProjectionAndCollapse { let exprs = expr .iter() - .map(|e| e.output_name_arc().clone()) + .map(|e| e.output_name().clone()) .collect::>(); let alp = IRBuilder::new(*input, expr_arena, lp_arena) - .project_simple(exprs.iter().map(|s| s.as_ref())) + .project_simple(exprs.iter().cloned()) .ok()? .build(); @@ -128,7 +128,11 @@ impl OptimizationRule for SimpleProjectionAndCollapse { input, by_column, slice, - sort_options, + sort_options: + sort_options @ SortMultipleOptions { + maintain_order: false, // `maintain_order=True` is influenced by result of earlier sorts + .. 
+ }, } => match lp_arena.get(*input) { Sort { input: inner, diff --git a/crates/polars-plan/src/plans/optimizer/count_star.rs b/crates/polars-plan/src/plans/optimizer/count_star.rs index cc5a841fb3de..1f20c83f6a87 100644 --- a/crates/polars-plan/src/plans/optimizer/count_star.rs +++ b/crates/polars-plan/src/plans/optimizer/count_star.rs @@ -31,8 +31,8 @@ impl OptimizationRule for CountStar { let alp = IR::MapFunction { input: placeholder_node, - function: FunctionNode::Count { - paths: count_star_expr.paths, + function: FunctionIR::FastCount { + sources: count_star_expr.sources, scan_type: count_star_expr.scan_type, alias: count_star_expr.alias, }, @@ -49,11 +49,11 @@ struct CountStarExpr { // Top node of the projection to replace node: Node, // Paths to the input files - paths: Arc>, + sources: ScanSources, // File Type scan_type: FileScan, // Column Alias - alias: Option>, + alias: Option, } // Visit the logical plan and return CountStarExpr with the expr information gathered @@ -66,12 +66,34 @@ fn visit_logical_plan_for_scan_paths( ) -> Option { match lp_arena.get(node) { IR::Union { inputs, .. } => { + enum MutableSources { + Paths(Vec), + Buffers(Vec), + } + let mut scan_type: Option = None; - let mut paths = Vec::with_capacity(inputs.len()); + let mut sources = None; for input in inputs { match visit_logical_plan_for_scan_paths(*input, lp_arena, expr_arena, true) { Some(expr) => { - paths.extend(expr.paths.iter().cloned()); + match (expr.sources, &mut sources) { + ( + ScanSources::Paths(paths), + Some(MutableSources::Paths(ref mut mutable_paths)), + ) => mutable_paths.extend_from_slice(&paths[..]), + (ScanSources::Paths(paths), None) => { + sources = Some(MutableSources::Paths(paths.to_vec())) + }, + ( + ScanSources::Buffers(buffers), + Some(MutableSources::Buffers(ref mut mutable_buffers)), + ) => mutable_buffers.extend_from_slice(&buffers[..]), + (ScanSources::Buffers(buffers), None) => { + sources = Some(MutableSources::Buffers(buffers.to_vec())) + }, + _ => return None, + } + match &scan_type { None => scan_type = Some(expr.scan_type), Some(scan_type) => { @@ -88,16 +110,20 @@ fn visit_logical_plan_for_scan_paths( } } Some(CountStarExpr { - paths: paths.into(), + sources: match sources { + Some(MutableSources::Paths(paths)) => ScanSources::Paths(paths.into()), + Some(MutableSources::Buffers(buffers)) => ScanSources::Buffers(buffers.into()), + None => ScanSources::default(), + }, scan_type: scan_type.unwrap(), node, alias: None, }) }, IR::Scan { - scan_type, paths, .. + scan_type, sources, .. } if !matches!(scan_type, FileScan::Anonymous { .. 
}) => Some(CountStarExpr { - paths: paths.clone(), + sources: sources.clone(), scan_type: scan_type.clone(), node, alias: None, @@ -125,7 +151,7 @@ fn visit_logical_plan_for_scan_paths( } } -fn is_valid_count_expr(e: &ExprIR, expr_arena: &Arena) -> (bool, Option>) { +fn is_valid_count_expr(e: &ExprIR, expr_arena: &Arena) -> (bool, Option) { match expr_arena.get(e.node()) { AExpr::Len => (true, e.get_alias().cloned()), _ => (false, None), diff --git a/crates/polars-plan/src/plans/optimizer/cse/cse_expr.rs b/crates/polars-plan/src/plans/optimizer/cse/cse_expr.rs index b069c9bb9309..6b7763760fa1 100644 --- a/crates/polars-plan/src/plans/optimizer/cse/cse_expr.rs +++ b/crates/polars-plan/src/plans/optimizer/cse/cse_expr.rs @@ -1,4 +1,5 @@ use hashbrown::hash_map::RawEntryMut; +use polars_utils::format_pl_smallstr; use polars_utils::vec::CapacityByFactor; use super::*; @@ -7,7 +8,6 @@ use crate::prelude::visitor::AexprNode; const SERIES_LIMIT: usize = 1000; -use ahash::RandomState; use polars_core::hashing::_boost_hash_combine; #[derive(Debug, Clone)] @@ -45,7 +45,7 @@ impl ProjectionExprs { pub(super) struct Identifier { inner: Option, last_node: Option, - hb: RandomState, + hb: PlRandomState, } impl Identifier { @@ -53,7 +53,7 @@ impl Identifier { Self { inner: None, last_node: None, - hb: RandomState::with_seed(0), + hb: PlRandomState::with_seed(0), } } @@ -75,8 +75,8 @@ impl Identifier { self.inner.is_some() } - fn materialize(&self) -> String { - format!("{}{:#x}", CSE_REPLACED, self.materialized_hash()) + fn materialize(&self) -> PlSmallStr { + format_pl_smallstr!("{}{:#x}", CSE_REPLACED, self.materialized_hash()) } fn materialized_hash(&self) -> u64 { @@ -591,7 +591,7 @@ impl RewritingVisitor for CommonSubExprRewriter<'_> { ); let name = id.materialize(); - node.assign(AExpr::col(name.as_ref()), arena); + node.assign(AExpr::col(name), arena); self.rewritten = true; Ok(node) @@ -724,7 +724,7 @@ impl CommonSubExprOptimizer { // intermediate temporary names starting with the `CSE_REPLACED` constant. 
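The count-star rewrite above only fires for unions whose inputs have homogeneous sources: paths are merged with paths, buffers with buffers, and any mix aborts the optimisation. A condensed sketch of that accumulation, with simplified source types:

```rust
use std::path::PathBuf;

// Stand-in for ScanSources/MutableSources: only homogeneous unions merge.
enum Sources {
    Paths(Vec<PathBuf>),
    Buffers(Vec<Vec<u8>>),
}

fn merge_sources(inputs: Vec<Sources>) -> Option<Sources> {
    let mut merged: Option<Sources> = None;
    for input in inputs {
        match (input, &mut merged) {
            (Sources::Paths(paths), Some(Sources::Paths(acc))) => acc.extend(paths),
            (Sources::Paths(paths), None) => merged = Some(Sources::Paths(paths)),
            (Sources::Buffers(bufs), Some(Sources::Buffers(acc))) => acc.extend(bufs),
            (Sources::Buffers(bufs), None) => merged = Some(Sources::Buffers(bufs)),
            // Mixing paths and in-memory buffers disables the fast count path.
            _ => return None,
        }
    }
    merged
}
```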
if !e.has_alias() { let name = ae_node.to_field(schema, expr_arena)?.name; - out_e.set_alias(ColumnName::from(name.as_str())); + out_e.set_alias(name.clone()); } out_e }; @@ -734,7 +734,7 @@ impl CommonSubExprOptimizer { for id in self.replaced_identifiers.inner.keys() { let (node, _count) = self.se_count.get(id, expr_arena).unwrap(); let name = id.materialize(); - let out_e = ExprIR::new(*node, OutputName::Alias(ColumnName::from(name))); + let out_e = ExprIR::new(*node, OutputName::Alias(name)); new_expr.push(out_e) } let expr = diff --git a/crates/polars-plan/src/plans/optimizer/cse/cse_lp.rs b/crates/polars-plan/src/plans/optimizer/cse/cse_lp.rs index cde7a0dea710..075414597edf 100644 --- a/crates/polars-plan/src/plans/optimizer/cse/cse_lp.rs +++ b/crates/polars-plan/src/plans/optimizer/cse/cse_lp.rs @@ -4,7 +4,6 @@ use super::*; use crate::prelude::visitor::IRNode; mod identifier_impl { - use ahash::RandomState; use polars_core::hashing::_boost_hash_combine; use super::*; @@ -17,7 +16,7 @@ mod identifier_impl { pub(super) struct Identifier { inner: Option, last_node: Option, - hb: RandomState, + hb: PlRandomState, } impl Identifier { @@ -48,7 +47,7 @@ mod identifier_impl { Self { inner: None, last_node: None, - hb: RandomState::with_seed(0), + hb: PlRandomState::with_seed(0), } } diff --git a/crates/polars-plan/src/plans/optimizer/fused.rs b/crates/polars-plan/src/plans/optimizer/fused.rs index d548147f65ce..cb84ca1b385f 100644 --- a/crates/polars-plan/src/plans/optimizer/fused.rs +++ b/crates/polars-plan/src/plans/optimizer/fused.rs @@ -106,10 +106,7 @@ impl OptimizationRule for FusedArithmetic { let node = expr_arena.add(fma); // we reordered the arguments, so we don't obey the left expression output name // rule anymore, that's why we alias - Ok(Some(Alias( - node, - ColumnName::from(output_field.name.as_str()), - ))) + Ok(Some(Alias(node, output_field.name.clone()))) }, _ => unreachable!(), }, diff --git a/crates/polars-plan/src/plans/optimizer/mod.rs b/crates/polars-plan/src/plans/optimizer/mod.rs index 58e84a607d06..34dc6dca9a29 100644 --- a/crates/polars-plan/src/plans/optimizer/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/mod.rs @@ -18,7 +18,6 @@ mod join_utils; mod predicate_pushdown; mod projection_pushdown; mod simplify_expr; -mod simplify_functions; mod slice_pushdown_expr; mod slice_pushdown_lp; mod stack_opt; @@ -34,7 +33,7 @@ use slice_pushdown_lp::SlicePushDown; pub use stack_opt::{OptimizationRule, StackOptimizer}; use self::flatten_union::FlattenUnionRule; -pub use crate::frame::{AllowedOptimizations, OptState}; +pub use crate::frame::{AllowedOptimizations, OptFlags}; pub use crate::plans::conversion::type_coercion::TypeCoercionRule; use crate::plans::optimizer::count_star::CountStar; #[cfg(feature = "cse")] @@ -59,7 +58,7 @@ pub(crate) fn init_hashmap(max_len: Option) -> PlHashMap { pub fn optimize( logical_plan: DslPlan, - opt_state: OptState, + mut opt_state: OptFlags, lp_arena: &mut Arena, expr_arena: &mut Arena, scratch: &mut Vec, @@ -67,40 +66,43 @@ pub fn optimize( ) -> PolarsResult { #[allow(dead_code)] let verbose = verbose(); + + // Gradually fill the rules passed to the optimizer + let opt = StackOptimizer {}; + let mut rules: Vec> = Vec::with_capacity(8); + + // Unset CSE + // This can be turned on again during ir-conversion. 
+ #[allow(clippy::eq_op)] + #[cfg(feature = "cse")] + if opt_state.contains(OptFlags::EAGER) { + opt_state &= !(OptFlags::COMM_SUBEXPR_ELIM | OptFlags::COMM_SUBEXPR_ELIM); + } + let mut lp_top = to_alp(logical_plan, expr_arena, lp_arena, &mut opt_state)?; + // get toggle values - let cluster_with_columns = opt_state.contains(OptState::CLUSTER_WITH_COLUMNS); - let predicate_pushdown = opt_state.contains(OptState::PREDICATE_PUSHDOWN); - let projection_pushdown = opt_state.contains(OptState::PROJECTION_PUSHDOWN); - let type_coercion = opt_state.contains(OptState::TYPE_COERCION); - let simplify_expr = opt_state.contains(OptState::SIMPLIFY_EXPR); - let slice_pushdown = opt_state.contains(OptState::SLICE_PUSHDOWN); - let streaming = opt_state.contains(OptState::STREAMING); - let fast_projection = opt_state.contains(OptState::FAST_PROJECTION); + let cluster_with_columns = opt_state.contains(OptFlags::CLUSTER_WITH_COLUMNS); + let predicate_pushdown = opt_state.contains(OptFlags::PREDICATE_PUSHDOWN); + let projection_pushdown = opt_state.contains(OptFlags::PROJECTION_PUSHDOWN); + let simplify_expr = opt_state.contains(OptFlags::SIMPLIFY_EXPR); + let slice_pushdown = opt_state.contains(OptFlags::SLICE_PUSHDOWN); + let streaming = opt_state.contains(OptFlags::STREAMING); + let fast_projection = opt_state.contains(OptFlags::FAST_PROJECTION); + // Don't run optimizations that don't make sense on a single node. // This keeps eager execution more snappy. - let eager = opt_state.contains(OptState::EAGER); + let eager = opt_state.contains(OptFlags::EAGER); #[cfg(feature = "cse")] - let comm_subplan_elim = opt_state.contains(OptState::COMM_SUBPLAN_ELIM) && !eager; + let comm_subplan_elim = opt_state.contains(OptFlags::COMM_SUBPLAN_ELIM); #[cfg(feature = "cse")] - let comm_subexpr_elim = opt_state.contains(OptState::COMM_SUBEXPR_ELIM) && !eager; + let comm_subexpr_elim = opt_state.contains(OptFlags::COMM_SUBEXPR_ELIM); #[cfg(not(feature = "cse"))] let comm_subexpr_elim = false; #[allow(unused_variables)] - let agg_scan_projection = opt_state.contains(OptState::FILE_CACHING) && !streaming && !eager; - - // Gradually fill the rules passed to the optimizer - let opt = StackOptimizer {}; - let mut rules: Vec> = Vec::with_capacity(8); + let agg_scan_projection = opt_state.contains(OptFlags::FILE_CACHING) && !streaming && !eager; - let mut lp_top = to_alp( - logical_plan, - expr_arena, - lp_arena, - simplify_expr, - type_coercion, - )?; // During debug we check if the optimizations have not modified the final schema. 
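`optimize` above now receives a mutable `OptFlags` and switches off common-subexpression elimination up front when running eagerly. The real type is a bitflags-style struct; the sketch below uses a plain `u32` newtype with made-up flag values purely to illustrate the toggle handling.

```rust
// Std-only stand-in for a bitflags-style OptFlags; the constants are illustrative.
#[derive(Clone, Copy, PartialEq, Eq)]
struct OptFlags(u32);

impl OptFlags {
    const EAGER: OptFlags = OptFlags(1 << 0);
    const COMM_SUBEXPR_ELIM: OptFlags = OptFlags(1 << 1);
    const PREDICATE_PUSHDOWN: OptFlags = OptFlags(1 << 2);

    fn contains(self, other: OptFlags) -> bool {
        self.0 & other.0 == other.0
    }

    fn remove(&mut self, other: OptFlags) {
        self.0 &= !other.0;
    }
}

fn main() {
    let mut opt_state = OptFlags(
        OptFlags::EAGER.0 | OptFlags::COMM_SUBEXPR_ELIM.0 | OptFlags::PREDICATE_PUSHDOWN.0,
    );
    // Eager execution disables CSE, mirroring the branch at the top of optimize().
    if opt_state.contains(OptFlags::EAGER) {
        opt_state.remove(OptFlags::COMM_SUBEXPR_ELIM);
    }
    assert!(!opt_state.contains(OptFlags::COMM_SUBEXPR_ELIM));
    assert!(opt_state.contains(OptFlags::PREDICATE_PUSHDOWN));
}
```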
#[cfg(debug_assertions)] let prev_schema = lp_arena.get(lp_top).schema(lp_arena).into_owned(); diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/group_by.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/group_by.rs index 208f2dca0973..6c6d4460b29e 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/group_by.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/group_by.rs @@ -12,7 +12,7 @@ pub(super) fn process_group_by( maintain_order: bool, apply: Option>, options: Arc, - acc_predicates: PlHashMap, ExprIR>, + acc_predicates: PlHashMap, ) -> PolarsResult { use IR::*; diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs index c787336af375..3b23faef8e04 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs @@ -8,7 +8,7 @@ struct LeftRight(T, T); fn should_block_join_specific( ae: &AExpr, how: &JoinType, - on_names: &PlHashSet>, + on_names: &PlHashSet, expr_arena: &Arena, schema_left: &Schema, schema_right: &Schema, @@ -90,11 +90,8 @@ fn all_pred_cols_in_left_on( expr_arena: &mut Arena, left_on: &[ExprIR], ) -> bool { - aexpr_to_leaf_names_iter(predicate.node(), expr_arena).all(|pred_column_name| { - left_on - .iter() - .any(|e| e.output_name() == pred_column_name.as_ref()) - }) + aexpr_to_leaf_names_iter(predicate.node(), expr_arena) + .all(|pred_column_name| left_on.iter().any(|e| e.output_name() == &pred_column_name)) } // Checks if a predicate refers to columns in both tables @@ -130,7 +127,7 @@ pub(super) fn process_join( right_on: Vec, schema: SchemaRef, options: Arc, - acc_predicates: PlHashMap, ExprIR>, + acc_predicates: PlHashMap, ) -> PolarsResult { use IR::*; let schema_left = lp_arena.get(input_left).schema(lp_arena); diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/keys.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/keys.rs index a11e2a8f0093..08eb14d2feb4 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/keys.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/keys.rs @@ -2,29 +2,27 @@ use super::*; // an invisible ascii token we use as delimiter -const HIDDEN_DELIMITER: char = '\u{1D17A}'; +const HIDDEN_DELIMITER: &str = "\u{1D17A}"; /// Determine the hashmap key by combining all the leaf column names of a predicate -pub(super) fn predicate_to_key(predicate: Node, expr_arena: &Arena) -> Arc { +pub(super) fn predicate_to_key(predicate: Node, expr_arena: &Arena) -> PlSmallStr { let mut iter = aexpr_to_leaf_names_iter(predicate, expr_arena); if let Some(first) = iter.next() { if let Some(second) = iter.next() { let mut new = String::with_capacity(32 * iter.size_hint().0); new.push_str(&first); - new.push(HIDDEN_DELIMITER); + new.push_str(HIDDEN_DELIMITER); new.push_str(&second); for name in iter { - new.push(HIDDEN_DELIMITER); + new.push_str(HIDDEN_DELIMITER); new.push_str(&name); } - return Arc::from(new); + return PlSmallStr::from_string(new); } first } else { - let mut s = String::new(); - s.push(HIDDEN_DELIMITER); - Arc::from(s) + PlSmallStr::from_str(HIDDEN_DELIMITER) } } diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs index c065b7e3a7cf..7cb0753e5a6d 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs +++ 
b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs @@ -59,7 +59,7 @@ impl<'a> PredicatePushDown<'a> { fn pushdown_and_assign( &self, input: Node, - acc_predicates: PlHashMap, ExprIR>, + acc_predicates: PlHashMap, lp_arena: &mut Arena, expr_arena: &mut Arena, ) -> PolarsResult<()> { @@ -73,7 +73,7 @@ impl<'a> PredicatePushDown<'a> { fn pushdown_and_continue( &self, lp: IR, - mut acc_predicates: PlHashMap, ExprIR>, + mut acc_predicates: PlHashMap, lp_arena: &mut Arena, expr_arena: &mut Arena, has_projections: bool, @@ -133,7 +133,7 @@ impl<'a> PredicatePushDown<'a> { }, e => e, }); - let predicate = to_aexpr(new_expr, expr_arena); + let predicate = to_aexpr(new_expr, expr_arena)?; e.set_node(predicate); } } @@ -188,7 +188,7 @@ impl<'a> PredicatePushDown<'a> { fn no_pushdown_restart_opt( &self, lp: IR, - acc_predicates: PlHashMap, ExprIR>, + acc_predicates: PlHashMap, lp_arena: &mut Arena, expr_arena: &mut Arena, ) -> PolarsResult { @@ -219,7 +219,7 @@ impl<'a> PredicatePushDown<'a> { fn no_pushdown( &self, lp: IR, - acc_predicates: PlHashMap, ExprIR>, + acc_predicates: PlHashMap, lp_arena: &mut Arena, expr_arena: &mut Arena, ) -> PolarsResult { @@ -243,7 +243,7 @@ impl<'a> PredicatePushDown<'a> { fn push_down( &self, lp: IR, - mut acc_predicates: PlHashMap, ExprIR>, + mut acc_predicates: PlHashMap, lp_arena: &mut Arena, expr_arena: &mut Arena, ) -> PolarsResult { @@ -262,7 +262,7 @@ impl<'a> PredicatePushDown<'a> { // // (2) can be pushed past (1) but they both have the same predicate // key name in the hashtable. - let tmp_key = Arc::::from(&*temporary_unique_key(&acc_predicates)); + let tmp_key = temporary_unique_key(&acc_predicates); acc_predicates.insert(tmp_key.clone(), predicate.clone()); let local_predicates = match pushdown_eligibility( @@ -325,7 +325,7 @@ impl<'a> PredicatePushDown<'a> { Ok(lp) }, Scan { - mut paths, + mut sources, file_info, hive_parts: mut scan_hive_parts, ref predicate, @@ -366,6 +366,9 @@ impl<'a> PredicatePushDown<'a> { if let (Some(hive_parts), Some(predicate)) = (&scan_hive_parts, &predicate) { if let Some(io_expr) = self.expr_eval.unwrap()(predicate, expr_arena) { if let Some(stats_evaluator) = io_expr.as_stats_evaluator() { + let paths = sources.as_paths().ok_or_else(|| { + polars_err!(nyi = "Hive partitioning of in-memory buffers") + })?; let mut new_paths = Vec::with_capacity(paths.len()); let mut new_hive_parts = Vec::with_capacity(paths.len()); @@ -400,7 +403,7 @@ impl<'a> PredicatePushDown<'a> { filter: None, }); } else { - paths = Arc::from(new_paths); + sources = ScanSources::Paths(new_paths.into()); scan_hive_parts = Some(Arc::from(new_hive_parts)); } } @@ -422,7 +425,7 @@ impl<'a> PredicatePushDown<'a> { let lp = if do_optimization { Scan { - paths, + sources, file_info, hive_parts, predicate, @@ -432,7 +435,7 @@ impl<'a> PredicatePushDown<'a> { } } else { let lp = Scan { - paths, + sources, file_info, hive_parts, predicate: None, @@ -454,12 +457,12 @@ impl<'a> PredicatePushDown<'a> { if let Some(ref subset) = options.subset { // Predicates on the subset can pass. 
let subset = subset.clone(); - let mut names_set = PlHashSet::<&str>::with_capacity(subset.len()); + let mut names_set = PlHashSet::::with_capacity(subset.len()); for name in subset.iter() { - names_set.insert(name.as_str()); + names_set.insert(name.clone()); } - let condition = |name: Arc| !names_set.contains(name.as_ref()); + let condition = |name: &PlSmallStr| !names_set.contains(name.as_str()); let local_predicates = transfer_to_local_by_name(expr_arena, &mut acc_predicates, condition); @@ -493,7 +496,7 @@ impl<'a> PredicatePushDown<'a> { MapFunction { ref function, .. } => { if function.allow_predicate_pd() { match function { - FunctionNode::Rename { existing, new, .. } => { + FunctionIR::Rename { existing, new, .. } => { let local_predicates = process_rename(&mut acc_predicates, expr_arena, existing, new)?; let lp = self.pushdown_and_continue( @@ -510,9 +513,8 @@ impl<'a> PredicatePushDown<'a> { expr_arena, )) }, - FunctionNode::Explode { columns, .. } => { - let condition = - |name: Arc| columns.iter().any(|s| s.as_ref() == &*name); + FunctionIR::Explode { columns, .. } => { + let condition = |name: &PlSmallStr| columns.iter().any(|s| s == name); // first columns that refer to the exploded columns should be done here let local_predicates = transfer_to_local_by_name( @@ -535,16 +537,22 @@ impl<'a> PredicatePushDown<'a> { expr_arena, )) }, - FunctionNode::Unpivot { args, .. } => { - let variable_name = args.variable_name.as_deref().unwrap_or("variable"); - let value_name = args.value_name.as_deref().unwrap_or("value"); + #[cfg(feature = "pivot")] + FunctionIR::Unpivot { args, .. } => { + let variable_name = &args + .variable_name + .clone() + .unwrap_or_else(|| PlSmallStr::from_static("variable")); + let value_name = &args + .value_name + .clone() + .unwrap_or_else(|| PlSmallStr::from_static("value")); // predicates that will be done at this level - let condition = |name: Arc| { - let name = &*name; + let condition = |name: &PlSmallStr| { name == variable_name || name == value_name - || args.on.iter().any(|s| s.as_str() == name) + || args.on.iter().any(|s| s == name) }; let local_predicates = transfer_to_local_by_name( expr_arena, @@ -660,8 +668,11 @@ impl<'a> PredicatePushDown<'a> { PythonScan { mut options } => { let predicate = predicate_at_scan(acc_predicates, None, expr_arena); if let Some(predicate) = predicate { - // Only accept streamable expressions as we want to apply the predicates to the batches. - if !is_streamable(predicate.node(), expr_arena, Context::Default) { + // For IO plugins we only accept streamable expressions as + // we want to apply the predicates to the batches. 
+ if !is_streamable(predicate.node(), expr_arena, Context::Default) + && matches!(options.python_source, PythonScanSource::IOPlugin) + { let lp = PythonScan { options }; return Ok(self.optional_apply_predicate( lp, diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/rename.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/rename.rs index e094564f4ddc..d31372009d8d 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/rename.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/rename.rs @@ -1,11 +1,11 @@ -use smartstring::alias::String as SmartString; +use polars_utils::pl_str::PlSmallStr; use super::*; use crate::prelude::optimizer::predicate_pushdown::keys::{key_has_name, predicate_to_key}; fn remove_any_key_referencing_renamed( new: &str, - acc_predicates: &mut PlHashMap, ExprIR>, + acc_predicates: &mut PlHashMap, local_predicates: &mut Vec, ) { let mut move_to_local = vec![]; @@ -21,10 +21,10 @@ fn remove_any_key_referencing_renamed( } pub(super) fn process_rename( - acc_predicates: &mut PlHashMap, ExprIR>, + acc_predicates: &mut PlHashMap, expr_arena: &mut Arena, - existing: &[SmartString], - new: &[SmartString], + existing: &[PlSmallStr], + new: &[PlSmallStr], ) -> PolarsResult> { let mut local_predicates = vec![]; for (existing, new) in existing.iter().zip(new.iter()) { @@ -51,7 +51,7 @@ pub(super) fn process_rename( // This ensure the optimization is pushed down. if let Some(mut e) = acc_predicates.remove(new.as_str()) { let new_node = - rename_matching_aexpr_leaf_names(e.node(), expr_arena, new, existing); + rename_matching_aexpr_leaf_names(e.node(), expr_arena, new, existing.clone()); e.set_node(new_node); acc_predicates.insert(predicate_to_key(new_node, expr_arena), e); } else { diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/utils.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/utils.rs index d7480c463b7c..7f14f2269cfd 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/utils.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/utils.rs @@ -12,7 +12,7 @@ fn combine_by_and(left: Node, right: Node, arena: &mut Arena) -> Node { /// Don't overwrite predicates but combine them. pub(super) fn insert_and_combine_predicate( - acc_predicates: &mut PlHashMap, ExprIR>, + acc_predicates: &mut PlHashMap, predicate: &ExprIR, arena: &mut Arena, ) { @@ -27,7 +27,8 @@ pub(super) fn insert_and_combine_predicate( .or_insert_with(|| predicate.clone()); } -pub(super) fn temporary_unique_key(acc_predicates: &PlHashMap, ExprIR>) -> String { +pub(super) fn temporary_unique_key(acc_predicates: &PlHashMap) -> PlSmallStr { + // TODO: Don't heap allocate during construction. let mut out_key = '\u{1D17A}'.to_string(); let mut existing_keys = acc_predicates.keys(); @@ -35,7 +36,7 @@ pub(super) fn temporary_unique_key(acc_predicates: &PlHashMap, ExprIR>) out_key.push_str(existing_keys.next().unwrap()); } - out_key + PlSmallStr::from_string(out_key) } pub(super) fn combine_predicates(iter: I, arena: &mut Arena) -> ExprIR @@ -59,7 +60,7 @@ where } pub(super) fn predicate_at_scan( - acc_predicates: PlHashMap, ExprIR>, + acc_predicates: PlHashMap, predicate: Option, expr_arena: &mut Arena, ) -> Option { @@ -111,18 +112,18 @@ pub(super) fn predicate_is_sort_boundary(node: Node, expr_arena: &Arena) /// transferred to local. 
pub(super) fn transfer_to_local_by_name( expr_arena: &Arena, - acc_predicates: &mut PlHashMap, ExprIR>, + acc_predicates: &mut PlHashMap, mut condition: F, ) -> Vec where - F: FnMut(Arc) -> bool, + F: FnMut(&PlSmallStr) -> bool, { let mut remove_keys = Vec::with_capacity(acc_predicates.len()); for (key, predicate) in &*acc_predicates { let root_names = aexpr_to_leaf_names(predicate.node(), expr_arena); for name in root_names { - if condition(name) { + if condition(&name) { remove_keys.push(key.clone()); break; } @@ -210,7 +211,7 @@ fn check_and_extend_predicate_pd_nodes( fn get_maybe_aliased_projection_to_input_name_map( e: &ExprIR, expr_arena: &Arena, -) -> Option<(Arc, Arc)> { +) -> Option<(PlSmallStr, PlSmallStr)> { let ae = expr_arena.get(e.node()); match e.get_alias() { Some(alias) => match ae { @@ -227,27 +228,27 @@ fn get_maybe_aliased_projection_to_input_name_map( pub enum PushdownEligibility { Full, // Partial can happen when there are window exprs. - Partial { to_local: Vec> }, + Partial { to_local: Vec }, NoPushdown, } #[allow(clippy::type_complexity)] pub fn pushdown_eligibility( projection_nodes: &[ExprIR], - new_predicates: &[(Arc, ExprIR)], - acc_predicates: &PlHashMap, ExprIR>, + new_predicates: &[(PlSmallStr, ExprIR)], + acc_predicates: &PlHashMap, expr_arena: &mut Arena, -) -> PolarsResult<(PushdownEligibility, PlHashMap, Arc>)> { +) -> PolarsResult<(PushdownEligibility, PlHashMap)> { let mut ae_nodes_stack = Vec::::with_capacity(4); let mut alias_to_col_map = - optimizer::init_hashmap::, Arc>(Some(projection_nodes.len())); + optimizer::init_hashmap::(Some(projection_nodes.len())); let mut col_to_alias_map = alias_to_col_map.clone(); let mut modified_projection_columns = - PlHashSet::>::with_capacity(projection_nodes.len()); + PlHashSet::::with_capacity(projection_nodes.len()); let mut has_window = false; - let mut common_window_inputs = PlHashSet::>::new(); + let mut common_window_inputs = PlHashSet::::new(); // Important: Names inserted into any data structure by this function are // all non-aliased. @@ -255,7 +256,7 @@ pub fn pushdown_eligibility( let process_projection_or_predicate = |ae_nodes_stack: &mut Vec, has_window: &mut bool, - common_window_inputs: &mut PlHashSet>| { + common_window_inputs: &mut PlHashSet| { debug_assert_eq!(ae_nodes_stack.len(), 1); while let Some(node) = ae_nodes_stack.pop() { @@ -276,7 +277,7 @@ pub fn pushdown_eligibility( }; let mut partition_by_names = - PlHashSet::>::with_capacity(partition_by.len()); + PlHashSet::::with_capacity(partition_by.len()); for node in partition_by.iter() { // Only accept col() @@ -333,7 +334,7 @@ pub fn pushdown_eligibility( continue; } - modified_projection_columns.insert(e.output_name_arc().clone()); + modified_projection_columns.insert(e.output_name().clone()); debug_assert!(ae_nodes_stack.is_empty()); ae_nodes_stack.push(e.node()); @@ -349,7 +350,7 @@ pub fn pushdown_eligibility( if has_window && !col_to_alias_map.is_empty() { // Rename to aliased names. - let mut new = PlHashSet::>::with_capacity(2 * common_window_inputs.len()); + let mut new = PlHashSet::::with_capacity(2 * common_window_inputs.len()); for key in common_window_inputs.into_iter() { if let Some(aliased) = col_to_alias_map.get(&key) { @@ -392,7 +393,7 @@ pub fn pushdown_eligibility( } // Note: has_window is constant. 
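transfer_to_local_by_name above moves predicates out of the pushdown set as soon as any of their leaf column names satisfies a condition, so they are applied at the current node instead. A std-only sketch of that split (hypothetical helper, not the Polars code):

// Hypothetical stand-in: each predicate carries the leaf column names it
// references; a predicate becomes "local" as soon as any of them matches.
fn split_local<P>(
    acc: &mut Vec<(Vec<String>, P)>,
    mut is_local: impl FnMut(&str) -> bool,
) -> Vec<P> {
    let mut local = Vec::new();
    let mut keep = Vec::new();
    for (names, pred) in acc.drain(..) {
        if names.iter().any(|n| is_local(n)) {
            local.push(pred);
        } else {
            keep.push((names, pred));
        }
    }
    *acc = keep;
    local
}

This mirrors why the loop in the diff breaks on the first matching leaf name: one hit is enough to make the whole predicate local.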
- let can_use_column = |col: &Arc| { + let can_use_column = |col: &str| { if has_window { common_window_inputs.contains(col) } else { diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/functions/mod.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/functions/mod.rs index 0a9c80827b9e..9b0dc58b8cf0 100644 --- a/crates/polars-plan/src/plans/optimizer/projection_pushdown/functions/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/functions/mod.rs @@ -1,5 +1,7 @@ +#[cfg(feature = "pivot")] mod unpivot; +#[cfg(feature = "pivot")] use unpivot::process_unpivot; use super::*; @@ -8,14 +10,14 @@ use super::*; pub(super) fn process_functions( proj_pd: &mut ProjectionPushDown, input: Node, - function: FunctionNode, + function: FunctionIR, mut acc_projections: Vec, - mut projected_names: PlHashSet>, + mut projected_names: PlHashSet, projections_seen: usize, lp_arena: &mut Arena, expr_arena: &mut Arena, ) -> PolarsResult { - use FunctionNode::*; + use FunctionIR::*; match function { Rename { ref existing, @@ -50,7 +52,12 @@ pub(super) fn process_functions( }, Explode { columns, .. } => { columns.iter().for_each(|name| { - add_str_to_accumulated(name, &mut acc_projections, &mut projected_names, expr_arena) + add_str_to_accumulated( + name.clone(), + &mut acc_projections, + &mut projected_names, + expr_arena, + ) }); proj_pd.pushdown_and_assign( input, @@ -64,6 +71,7 @@ pub(super) fn process_functions( .explode(columns.clone()) .build()) }, + #[cfg(feature = "pivot")] Unpivot { ref args, .. } => { let lp = IR::MapFunction { input, @@ -95,7 +103,7 @@ pub(super) fn process_functions( expr_arena, ) } - let expands_schema = matches!(function, FunctionNode::Unnest { .. }); + let expands_schema = matches!(function, FunctionIR::Unnest { .. 
}); let local_projections = proj_pd.pushdown_and_assign_check_schema( input, diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/functions/unpivot.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/functions/unpivot.rs index 70704f76fa9b..518c8e081c5e 100644 --- a/crates/polars-plan/src/plans/optimizer/projection_pushdown/functions/unpivot.rs +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/functions/unpivot.rs @@ -4,7 +4,7 @@ use super::*; pub(super) fn process_unpivot( proj_pd: &mut ProjectionPushDown, lp: IR, - args: &Arc, + args: &Arc, input: Node, acc_projections: Vec, projections_seen: usize, @@ -29,10 +29,20 @@ pub(super) fn process_unpivot( // make sure that the requested columns are projected args.index.iter().for_each(|name| { - add_str_to_accumulated(name, &mut acc_projections, &mut projected_names, expr_arena) + add_str_to_accumulated( + name.clone(), + &mut acc_projections, + &mut projected_names, + expr_arena, + ) }); args.on.iter().for_each(|name| { - add_str_to_accumulated(name, &mut acc_projections, &mut projected_names, expr_arena) + add_str_to_accumulated( + name.clone(), + &mut acc_projections, + &mut projected_names, + expr_arena, + ) }); proj_pd.pushdown_and_assign( diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/generic.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/generic.rs index e1326864a283..ee9a60738f22 100644 --- a/crates/polars-plan/src/plans/optimizer/projection_pushdown/generic.rs +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/generic.rs @@ -5,7 +5,7 @@ pub(super) fn process_generic( proj_pd: &mut ProjectionPushDown, lp: IR, acc_projections: Vec, - projected_names: PlHashSet>, + projected_names: PlHashSet, projections_seen: usize, lp_arena: &mut Arena, expr_arena: &mut Arena, diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/group_by.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/group_by.rs index 1ed124ef79ea..1dc1abcc3ba8 100644 --- a/crates/polars-plan/src/plans/optimizer/projection_pushdown/group_by.rs +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/group_by.rs @@ -11,7 +11,7 @@ pub(super) fn process_group_by( maintain_order: bool, options: Arc, acc_projections: Vec, - projected_names: PlHashSet>, + projected_names: PlHashSet, projections_seen: usize, lp_arena: &mut Arena, expr_arena: &mut Arena, @@ -49,7 +49,7 @@ pub(super) fn process_group_by( .into_iter() .filter(|agg| { if has_pushed_down && projections_seen > 0 { - projected_names.contains(agg.output_name_arc()) + projected_names.contains(agg.output_name()) } else { true } @@ -68,17 +68,13 @@ pub(super) fn process_group_by( // make sure that the dynamic key is projected #[cfg(feature = "dynamic_group_by")] if let Some(options) = &options.dynamic { - let node = expr_arena.add(AExpr::Column(ColumnName::from( - options.index_column.as_str(), - ))); + let node = expr_arena.add(AExpr::Column(options.index_column.clone())); add_expr_to_accumulated(node, &mut acc_projections, &mut names, expr_arena); } // make sure that the rolling key is projected #[cfg(feature = "dynamic_group_by")] if let Some(options) = &options.rolling { - let node = expr_arena.add(AExpr::Column(ColumnName::from( - options.index_column.as_str(), - ))); + let node = expr_arena.add(AExpr::Column(options.index_column.clone())); add_expr_to_accumulated(node, &mut acc_projections, &mut names, expr_arena); } diff --git 
a/crates/polars-plan/src/plans/optimizer/projection_pushdown/hstack.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/hstack.rs index 628741511d86..8096b5bde3d8 100644 --- a/crates/polars-plan/src/plans/optimizer/projection_pushdown/hstack.rs +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/hstack.rs @@ -7,7 +7,7 @@ pub(super) fn process_hstack( mut exprs: Vec, options: ProjectionOptions, mut acc_projections: Vec, - mut projected_names: PlHashSet>, + mut projected_names: PlHashSet, projections_seen: usize, lp_arena: &mut Arena, expr_arena: &mut Arena, diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/joins.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/joins.rs index 007ca07cf206..6eb8bc033015 100644 --- a/crates/polars-plan/src/plans/optimizer/projection_pushdown/joins.rs +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/joins.rs @@ -8,11 +8,11 @@ fn add_keys_to_accumulated_state( expr: Node, acc_projections: &mut Vec, local_projection: &mut Vec, - projected_names: &mut PlHashSet>, + projected_names: &mut PlHashSet, expr_arena: &mut Arena, // only for left hand side table we add local names add_local: bool, -) -> Option> { +) -> Option { add_expr_to_accumulated(expr, acc_projections, projected_names, expr_arena); // the projections may do more than simply project. // e.g. col("foo").truncate() * col("bar") @@ -43,7 +43,7 @@ pub(super) fn process_asof_join( right_on: Vec, options: Arc, acc_projections: Vec, - _projected_names: PlHashSet>, + _projected_names: PlHashSet, projections_seen: usize, lp_arena: &mut Arena, expr_arena: &mut Arena, @@ -78,7 +78,7 @@ pub(super) fn process_asof_join( for name in left_by { let add = _projected_names.contains(name.as_str()); - let node = expr_arena.add(AExpr::Column(ColumnName::from(name.as_str()))); + let node = expr_arena.add(AExpr::Column(name.clone())); add_keys_to_accumulated_state( node, &mut pushdown_left, @@ -89,7 +89,7 @@ pub(super) fn process_asof_join( ); } for name in right_by { - let node = expr_arena.add(AExpr::Column(ColumnName::from(name.as_str()))); + let node = expr_arena.add(AExpr::Column(name.clone())); add_keys_to_accumulated_state( node, &mut pushdown_right, @@ -202,7 +202,7 @@ pub(super) fn process_join( right_on: Vec, mut options: Arc, acc_projections: Vec, - projected_names: PlHashSet>, + projected_names: PlHashSet, projections_seen: usize, lp_arena: &mut Arena, expr_arena: &mut Arena, @@ -252,7 +252,7 @@ pub(super) fn process_join( // We need the join columns so we push the projection downwards for e in &left_on { - if !local_projected_names.insert(e.output_name_arc().clone()) { + if !local_projected_names.insert(e.output_name().clone()) { continue; } @@ -384,8 +384,8 @@ fn process_projection( proj: ColumnNode, pushdown_left: &mut Vec, pushdown_right: &mut Vec, - names_left: &mut PlHashSet>, - names_right: &mut PlHashSet>, + names_left: &mut PlHashSet, + names_right: &mut PlHashSet, expr_arena: &mut Arena, local_projection: &mut Vec, add_local: bool, @@ -416,16 +416,17 @@ fn process_projection( // Column name of the projection without any alias. let leaf_column_name = column_node_to_name(proj, expr_arena).clone(); - let suffix = options.args.suffix(); + let suffix = options.args.suffix().as_str(); // If _right suffix exists we need to push a projection down without this // suffix. if leaf_column_name.ends_with(suffix) && join_schema.contains(leaf_column_name.as_ref()) { // downwards name is the name without the _right i.e. "foo". 
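One detail in process_join above: the join key columns must reach both inputs even when the user projection never mentioned them, and each name should only be pushed once. A toy version of that bookkeeping (std types, not the Polars API):

use std::collections::HashSet;

// Push a join key to the downstream projection unless it was already pushed.
fn push_join_key(pushed: &mut Vec<String>, seen: &mut HashSet<String>, key: &str) {
    // HashSet::insert returns false when the name was already present.
    if seen.insert(key.to_string()) {
        pushed.push(key.to_string());
    }
}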
let downwards_name = split_suffix(leaf_column_name.as_ref(), suffix); + let downwards_name = PlSmallStr::from_str(downwards_name); - let downwards_name_column = expr_arena.add(AExpr::Column(Arc::from(downwards_name))); + let downwards_name_column = expr_arena.add(AExpr::Column(downwards_name.clone())); // project downwards and locally immediately alias to prevent wrong projections - if names_right.insert(ColumnName::from(downwards_name)) { + if names_right.insert(downwards_name) { pushdown_right.push(ColumnNode(downwards_name_column)); } local_projection.push(proj); @@ -470,7 +471,7 @@ fn resolve_join_suffixes( expr_arena: &mut Arena, local_projection: &[ColumnNode], ) -> PolarsResult { - let suffix = options.args.suffix(); + let suffix = options.args.suffix().as_str(); let alp = IRBuilder::new(input_left, expr_arena, lp_arena) .join(input_right, left_on, right_on, options.clone()) .build(); @@ -482,8 +483,8 @@ fn resolve_join_suffixes( .map(|proj| { let name = column_node_to_name(*proj, expr_arena).clone(); if name.ends_with(suffix) && schema_after_join.get(&name).is_none() { - let downstream_name = &name.as_ref()[..name.len() - suffix.len()]; - let col = AExpr::Column(ColumnName::from(downstream_name)); + let downstream_name = &name.as_str()[..name.len() - suffix.len()]; + let col = AExpr::Column(downstream_name.into()); let node = expr_arena.add(col); all_columns = false; ExprIR::new(node, OutputName::Alias(name.clone())) @@ -496,7 +497,7 @@ fn resolve_join_suffixes( let builder = IRBuilder::from_lp(alp, expr_arena, lp_arena); Ok(if all_columns { builder - .project_simple(projections.iter().map(|e| e.output_name()))? + .project_simple(projections.iter().map(|e| e.output_name().clone()))? .build() } else { builder.project(projections, Default::default()).build() diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs index 0455ba7f5e9c..61c86e789d95 100644 --- a/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/mod.rs @@ -29,41 +29,43 @@ use crate::utils::aexpr_to_leaf_names; fn init_vec() -> Vec { Vec::with_capacity(16) } -fn init_set() -> PlHashSet> { +fn init_set() -> PlHashSet { PlHashSet::with_capacity(32) } /// utility function to get names of the columns needed in projection at scan level fn get_scan_columns( - acc_projections: &Vec, + acc_projections: &[ColumnNode], expr_arena: &Arena, row_index: Option<&RowIndex>, file_path_col: Option<&str>, -) -> Option> { - let mut with_columns = None; +) -> Option> { if !acc_projections.is_empty() { - let mut columns = Vec::with_capacity(acc_projections.len()); - for expr in acc_projections { - let name = column_node_to_name(*expr, expr_arena); - // we shouldn't project the row-count column, as that is generated - // in the scan - if let Some(ri) = row_index { - if ri.name.as_ref() == name.as_ref() { - continue; - } - } + Some( + acc_projections + .iter() + .filter_map(|node| { + let name = column_node_to_name(*node, expr_arena); + + if let Some(ri) = row_index { + if ri.name == name { + return None; + } + } - if let Some(file_path_col) = file_path_col { - if file_path_col == name.as_ref() { - continue; - } - } + if let Some(file_path_col) = file_path_col { + if file_path_col == name.as_str() { + return None; + } + } - columns.push((**name).to_owned()) - } - with_columns = Some(Arc::from(columns)); + Some(name.clone()) + }) + .collect::>(), + ) + } else { + None } - 
with_columns } /// split in a projection vec that can be pushed down and a projection vec that should be used @@ -78,7 +80,7 @@ fn split_acc_projections( down_schema: &Schema, expr_arena: &Arena, expands_schema: bool, -) -> (Vec, Vec, PlHashSet>) { +) -> (Vec, Vec, PlHashSet) { // If node above has as many columns as the projection there is nothing to pushdown. if !expands_schema && down_schema.len() == acc_projections.len() { let local_projections = acc_projections; @@ -100,7 +102,7 @@ fn split_acc_projections( fn add_expr_to_accumulated( expr: Node, acc_projections: &mut Vec, - projected_names: &mut PlHashSet>, + projected_names: &mut PlHashSet, expr_arena: &Arena, ) { for root_node in aexpr_to_column_nodes_iter(expr, expr_arena) { @@ -112,14 +114,14 @@ fn add_expr_to_accumulated( } fn add_str_to_accumulated( - name: &str, + name: PlSmallStr, acc_projections: &mut Vec, - projected_names: &mut PlHashSet>, + projected_names: &mut PlHashSet, expr_arena: &mut Arena, ) { // if empty: all columns are already projected. - if !acc_projections.is_empty() && !projected_names.contains(name) { - let node = expr_arena.add(AExpr::Column(ColumnName::from(name))); + if !acc_projections.is_empty() && !projected_names.contains(&name) { + let node = expr_arena.add(AExpr::Column(name)); add_expr_to_accumulated(node, acc_projections, projected_names, expr_arena); } } @@ -225,8 +227,8 @@ impl ProjectionPushDown { proj: ColumnNode, pushdown_left: &mut Vec, pushdown_right: &mut Vec, - names_left: &mut PlHashSet>, - names_right: &mut PlHashSet>, + names_left: &mut PlHashSet, + names_right: &mut PlHashSet, expr_arena: &Arena, ) -> (bool, bool) { let mut pushed_at_least_one = false; @@ -257,7 +259,7 @@ impl ProjectionPushDown { &mut self, input: Node, acc_projections: Vec, - names: PlHashSet>, + names: PlHashSet, projections_seen: usize, lp_arena: &mut Arena, expr_arena: &mut Arena, @@ -323,7 +325,7 @@ impl ProjectionPushDown { &mut self, logical_plan: IR, mut acc_projections: Vec, - mut projected_names: PlHashSet>, + mut projected_names: PlHashSet, projections_seen: usize, lp_arena: &mut Arena, expr_arena: &mut Arena, @@ -344,7 +346,7 @@ impl ProjectionPushDown { expr_arena, ), SimpleProjection { columns, input, .. 
} => { - let exprs = names_to_expr_irs(columns.iter_names(), expr_arena); + let exprs = names_to_expr_irs(columns.iter_names_cloned(), expr_arena); process_projection( self, input, @@ -396,7 +398,7 @@ impl ProjectionPushDown { Ok(PythonScan { options }) }, Scan { - paths, + sources, mut file_info, mut hive_parts, scan_type, @@ -508,7 +510,7 @@ impl ProjectionPushDown { } }; let lp = Scan { - paths, + sources, file_info, hive_parts, output_schema, @@ -563,7 +565,7 @@ impl ProjectionPushDown { if let Some(subset) = options.subset.as_ref() { subset.iter().for_each(|name| { add_str_to_accumulated( - name, + name.clone(), &mut acc_projections, &mut projected_names, expr_arena, @@ -574,7 +576,7 @@ impl ProjectionPushDown { let input_schema = lp_arena.get(input).schema(lp_arena); for name in input_schema.iter_names() { add_str_to_accumulated( - name.as_str(), + name.clone(), &mut acc_projections, &mut projected_names, expr_arena, diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/projection.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/projection.rs index 4fda3a2432bc..6b1106a7ca19 100644 --- a/crates/polars-plan/src/plans/optimizer/projection_pushdown/projection.rs +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/projection.rs @@ -18,7 +18,7 @@ fn check_double_projection( expr: &ExprIR, expr_arena: &mut Arena, acc_projections: &mut Vec, - projected_names: &mut PlHashSet>, + projected_names: &mut PlHashSet, ) { // Factor out the pruning function fn prune_projections_by_name( @@ -26,7 +26,7 @@ fn check_double_projection( name: &str, expr_arena: &Arena, ) { - acc_projections.retain(|node| column_node_to_name(*node, expr_arena).as_ref() != name); + acc_projections.retain(|node| column_node_to_name(*node, expr_arena) != name); } if let Some(name) = expr.get_non_projected_name() { if projected_names.remove(name) { @@ -50,7 +50,7 @@ pub(super) fn process_projection( input: Node, mut exprs: Vec, mut acc_projections: Vec, - mut projected_names: PlHashSet>, + mut projected_names: PlHashSet, projections_seen: usize, lp_arena: &mut Arena, expr_arena: &mut Arena, @@ -70,7 +70,7 @@ pub(super) fn process_projection( // simply select the last column // NOTE: the first can be the inserted index column, so that might not work let (first_name, _) = input_schema.try_get_at_index(input_schema.len() - 1)?; - let expr = expr_arena.add(AExpr::Column(ColumnName::from(first_name.as_str()))); + let expr = expr_arena.add(AExpr::Column(first_name.clone())); if !acc_projections.is_empty() { check_double_projection( &exprs[0], @@ -97,7 +97,7 @@ pub(super) fn process_projection( for e in exprs { if has_pushed_down { // remove projections that are not used upstream - if !projected_names.contains(e.output_name_arc()) { + if !projected_names.contains(e.output_name()) { continue; } diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/rename.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/rename.rs index 3f0a39d05a7b..37142ba90943 100644 --- a/crates/polars-plan/src/plans/optimizer/projection_pushdown/rename.rs +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/rename.rs @@ -1,6 +1,6 @@ use std::collections::BTreeSet; -use smartstring::alias::String as SmartString; +use polars_utils::pl_str::PlSmallStr; use super::*; @@ -15,8 +15,8 @@ fn iter_and_update_nodes( let node = column_node.0; if !processed.contains(&node.0) { // We walk the query backwards, so we rename new to existing - if column_node_to_name(*column_node, 
expr_arena).as_ref() == new { - let new_node = expr_arena.add(AExpr::Column(ColumnName::from(existing))); + if column_node_to_name(*column_node, expr_arena) == new { + let new_node = expr_arena.add(AExpr::Column(PlSmallStr::from_str(existing))); *column_node = ColumnNode(new_node); processed.insert(new_node.0); } @@ -27,28 +27,24 @@ fn iter_and_update_nodes( #[allow(clippy::too_many_arguments)] pub(super) fn process_rename( acc_projections: &mut [ColumnNode], - projected_names: &mut PlHashSet>, + projected_names: &mut PlHashSet, expr_arena: &mut Arena, - existing: &[SmartString], - new: &[SmartString], + existing: &[PlSmallStr], + new: &[PlSmallStr], swapping: bool, ) -> PolarsResult<()> { if swapping { - let reverse_map: PlHashMap<_, _> = new - .iter() - .map(|s| s.as_str()) - .zip(existing.iter().map(|s| s.as_str())) - .collect(); + let reverse_map: PlHashMap<_, _> = + new.iter().cloned().zip(existing.iter().cloned()).collect(); let mut new_projected_names = PlHashSet::with_capacity(projected_names.len()); for col in acc_projections { let name = column_node_to_name(*col, expr_arena); - if let Some(previous) = reverse_map.get(name.as_ref()) { - let previous: Arc = Arc::from(*previous); + if let Some(previous) = reverse_map.get(name) { let new = expr_arena.add(AExpr::Column(previous.clone())); *col = ColumnNode(new); - let _ = new_projected_names.insert(previous); + let _ = new_projected_names.insert(previous.clone()); } else { let _ = new_projected_names.insert(name.clone()); } @@ -58,7 +54,7 @@ pub(super) fn process_rename( let mut processed = BTreeSet::new(); for (existing, new) in existing.iter().zip(new.iter()) { if projected_names.remove(new.as_str()) { - let name: Arc = ColumnName::from(existing.as_str()); + let name = existing.clone(); projected_names.insert(name); iter_and_update_nodes(existing, new, acc_projections, expr_arena, &mut processed); } diff --git a/crates/polars-plan/src/plans/optimizer/projection_pushdown/semi_anti_join.rs b/crates/polars-plan/src/plans/optimizer/projection_pushdown/semi_anti_join.rs index 6b0863fa11cc..2cdb1edab260 100644 --- a/crates/polars-plan/src/plans/optimizer/projection_pushdown/semi_anti_join.rs +++ b/crates/polars-plan/src/plans/optimizer/projection_pushdown/semi_anti_join.rs @@ -9,7 +9,7 @@ pub(super) fn process_semi_anti_join( right_on: Vec, options: Arc, acc_projections: Vec, - _projected_names: PlHashSet>, + _projected_names: PlHashSet, projections_seen: usize, lp_arena: &mut Arena, expr_arena: &mut Arena, diff --git a/crates/polars-plan/src/plans/optimizer/simplify_expr.rs b/crates/polars-plan/src/plans/optimizer/simplify_expr/mod.rs similarity index 87% rename from crates/polars-plan/src/plans/optimizer/simplify_expr.rs rename to crates/polars-plan/src/plans/optimizer/simplify_expr/mod.rs index 86cbcc0e5e82..1df68a0adcfa 100644 --- a/crates/polars-plan/src/plans/optimizer/simplify_expr.rs +++ b/crates/polars-plan/src/plans/optimizer/simplify_expr/mod.rs @@ -1,8 +1,24 @@ +mod simplify_functions; + use polars_utils::floor_divmod::FloorDivMod; use polars_utils::total_ord::ToTotalOrd; +use simplify_functions::optimize_functions; use crate::plans::*; -use crate::prelude::optimizer::simplify_functions::optimize_functions; + +fn new_null_count(input: &[ExprIR]) -> AExpr { + AExpr::Function { + input: input.to_vec(), + function: FunctionExpr::NullCount, + options: FunctionOptions { + collect_groups: ApplyOptions::GroupWise, + fmt_str: "", + cast_to_supertypes: None, + check_lengths: UnsafeBool::default(), + flags: 
FunctionFlags::ALLOW_GROUP_AWARE | FunctionFlags::RETURNS_SCALAR, + }, + } +} macro_rules! eval_binary_same_type { ($lhs:expr, $rhs:expr, |$l: ident, $r: ident| $ret: expr) => {{ @@ -407,7 +423,7 @@ fn string_addition_to_linear_concat( _ => Some(AExpr::Function { input: vec![left_e, right_e], function: StringFunction::ConcatHorizontal { - delimiter: "".to_string(), + delimiter: "".into(), ignore_nulls: false, } .into(), @@ -440,6 +456,75 @@ impl OptimizationRule for SimplifyExprRule { let expr = expr_arena.get(expr_node).clone(); let out = match &expr { + // drop_nulls().len() -> len() - null_count() + // drop_nulls().count() -> len() - null_count() + AExpr::Agg(IRAggExpr::Count(input, _)) => { + let input_expr = expr_arena.get(*input); + match input_expr { + AExpr::Function { + input, + function: FunctionExpr::DropNulls, + options: _, + } => { + // we should perform optimization only if the original expression is a column + // so in case of disabled CSE, we will not suffer from performance regression + if input.len() == 1 { + let drop_nulls_input_node = input[0].node(); + match expr_arena.get(drop_nulls_input_node) { + AExpr::Column(_) => Some(AExpr::BinaryExpr { + op: Operator::Minus, + right: expr_arena.add(new_null_count(input)), + left: expr_arena.add(AExpr::Agg(IRAggExpr::Count( + drop_nulls_input_node, + true, + ))), + }), + _ => None, + } + } else { + None + } + }, + _ => None, + } + }, + // is_null().sum() -> null_count() + // is_not_null().sum() -> len() - null_count() + AExpr::Agg(IRAggExpr::Sum(input)) => { + let input_expr = expr_arena.get(*input); + match input_expr { + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsNull), + options: _, + } => Some(new_null_count(input)), + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsNotNull), + options: _, + } => { + // we should perform optimization only if the original expression is a column + // so in case of disabled CSE, we will not suffer from performance regression + if input.len() == 1 { + let is_not_null_input_node = input[0].node(); + match expr_arena.get(is_not_null_input_node) { + AExpr::Column(_) => Some(AExpr::BinaryExpr { + op: Operator::Minus, + right: expr_arena.add(new_null_count(input)), + left: expr_arena.add(AExpr::Agg(IRAggExpr::Count( + is_not_null_input_node, + true, + ))), + }), + _ => None, + } + } else { + None + } + }, + _ => None, + } + }, // lit(left) + lit(right) => lit(left + right) // and null propagation AExpr::BinaryExpr { left, op, right } => { @@ -631,7 +716,7 @@ fn test_expr_to_aexp() { let expr = Expr::Literal(LiteralValue::Int8(0)); let mut arena = Arena::new(); - let aexpr = to_aexpr(expr, &mut arena); + let aexpr = to_aexpr(expr, &mut arena).unwrap(); assert_eq!(aexpr, Node(0)); assert!(matches!( arena.get(aexpr), diff --git a/crates/polars-plan/src/plans/optimizer/simplify_functions.rs b/crates/polars-plan/src/plans/optimizer/simplify_expr/simplify_functions.rs similarity index 74% rename from crates/polars-plan/src/plans/optimizer/simplify_functions.rs rename to crates/polars-plan/src/plans/optimizer/simplify_expr/simplify_functions.rs index 504af2e517f9..2b5493c62e6b 100644 --- a/crates/polars-plan/src/plans/optimizer/simplify_functions.rs +++ b/crates/polars-plan/src/plans/optimizer/simplify_expr/simplify_functions.rs @@ -7,6 +7,87 @@ pub(super) fn optimize_functions( expr_arena: &mut Arena, ) -> PolarsResult> { let out = match function { + // is_null().any() -> null_count() > 0 + // is_not_null().any() -> null_count() < len() + 
// CORRECTNESS: we can ignore 'ignore_nulls' since is_null/is_not_null never produces NULLS + FunctionExpr::Boolean(BooleanFunction::Any { ignore_nulls: _ }) => { + let input_node = expr_arena.get(input[0].node()); + match input_node { + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsNull), + options: _, + } => Some(AExpr::BinaryExpr { + left: expr_arena.add(new_null_count(input)), + op: Operator::Gt, + right: expr_arena.add(AExpr::Literal(LiteralValue::new_idxsize(0))), + }), + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsNotNull), + options: _, + } => { + // we should perform optimization only if the original expression is a column + // so in case of disabled CSE, we will not suffer from performance regression + if input.len() == 1 { + let is_not_null_input_node = input[0].node(); + match expr_arena.get(is_not_null_input_node) { + AExpr::Column(_) => Some(AExpr::BinaryExpr { + op: Operator::Lt, + left: expr_arena.add(new_null_count(input)), + right: expr_arena.add(AExpr::Agg(IRAggExpr::Count( + is_not_null_input_node, + true, + ))), + }), + _ => None, + } + } else { + None + } + }, + _ => None, + } + }, + // is_null().all() -> null_count() == len() + // is_not_null().all() -> null_count() == 0 + FunctionExpr::Boolean(BooleanFunction::All { ignore_nulls: _ }) => { + let input_node = expr_arena.get(input[0].node()); + match input_node { + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsNull), + options: _, + } => { + // we should perform optimization only if the original expression is a column + // so in case of disabled CSE, we will not suffer from performance regression + if input.len() == 1 { + let is_null_input_node = input[0].node(); + match expr_arena.get(is_null_input_node) { + AExpr::Column(_) => Some(AExpr::BinaryExpr { + op: Operator::Eq, + right: expr_arena.add(new_null_count(input)), + left: expr_arena + .add(AExpr::Agg(IRAggExpr::Count(is_null_input_node, true))), + }), + _ => None, + } + } else { + None + } + }, + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsNotNull), + options: _, + } => Some(AExpr::BinaryExpr { + left: expr_arena.add(new_null_count(input)), + op: Operator::Eq, + right: expr_arena.add(AExpr::Literal(LiteralValue::new_idxsize(0))), + }), + _ => None, + } + }, // sort().reverse() -> sort(reverse) // sort_by().reverse() -> sort_by(reverse) FunctionExpr::Reverse => { diff --git a/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs b/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs index 34ab66e9499c..b656795f53d2 100644 --- a/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs +++ b/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs @@ -165,7 +165,7 @@ impl SlicePushDown { } #[cfg(feature = "csv")] (Scan { - paths, + sources, file_info, hive_parts, output_schema, @@ -176,7 +176,7 @@ impl SlicePushDown { file_options.slice = Some((0, state.offset as usize + state.len as usize)); let lp = Scan { - paths, + sources, file_info, hive_parts, output_schema, @@ -189,7 +189,7 @@ impl SlicePushDown { }, #[cfg(feature = "parquet")] (Scan { - paths, + sources, file_info, hive_parts, output_schema, @@ -200,7 +200,7 @@ impl SlicePushDown { file_options.slice = Some((state.offset, state.len as usize)); let lp = Scan { - paths, + sources, file_info, hive_parts, output_schema, @@ -213,7 +213,7 @@ impl SlicePushDown { }, // TODO! we currently skip slice pushdown if there is a predicate. 
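The rewrites introduced in the hunks above replace per-element boolean reductions with arithmetic on null_count() and len(). A quick plain-Rust sanity check of the underlying identities (not Polars code):

fn main() {
    let col: Vec<Option<i32>> = vec![Some(1), None, Some(3), None];
    let len = col.len();
    let null_count = col.iter().filter(|v| v.is_none()).count();

    // is_null().sum()      -> null_count()
    assert_eq!(col.iter().map(|v| v.is_none() as usize).sum::<usize>(), null_count);
    // is_not_null().sum()  -> len() - null_count()
    assert_eq!(col.iter().map(|v| v.is_some() as usize).sum::<usize>(), len - null_count);
    // drop_nulls().len()   -> len() - null_count()
    assert_eq!(col.iter().flatten().count(), len - null_count);
    // is_null().any()      -> null_count() > 0
    assert_eq!(col.iter().any(|v| v.is_none()), null_count > 0);
    // is_not_null().any()  -> null_count() < len()
    assert_eq!(col.iter().any(|v| v.is_some()), null_count < len);
    // is_null().all()      -> null_count() == len()
    assert_eq!(col.iter().all(|v| v.is_none()), null_count == len);
    // is_not_null().all()  -> null_count() == 0
    assert_eq!(col.iter().all(|v| v.is_some()), null_count == 0);
}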
(Scan { - paths, + sources, file_info, hive_parts, output_schema, @@ -224,7 +224,7 @@ impl SlicePushDown { options.slice = Some((0, state.len as usize)); let lp = Scan { - paths, + sources, file_info, hive_parts, output_schema, @@ -385,8 +385,7 @@ impl SlicePushDown { // other blocking nodes | m @ (DataFrameScan {..}, _) | m @ (Sort {..}, _) - | m @ (MapFunction {function: FunctionNode::Explode {..}, ..}, _) - | m @ (MapFunction {function: FunctionNode::Unpivot {..}, ..}, _) + | m @ (MapFunction {function: FunctionIR::Explode {..}, ..}, _) | m @ (Cache {..}, _) | m @ (Distinct {..}, _) | m @ (GroupBy{..},_) @@ -395,7 +394,12 @@ impl SlicePushDown { => { let (lp, state) = m; self.no_pushdown_restart_opt(lp, state, lp_arena, expr_arena) - } + }, + #[cfg(feature = "pivot")] + m @ (MapFunction {function: FunctionIR::Unpivot {..}, ..}, _) => { + let (lp, state) = m; + self.no_pushdown_restart_opt(lp, state, lp_arena, expr_arena) + }, // [Pushdown] (MapFunction {input, function}, _) if function.allow_predicate_pd() => { let lp = MapFunction {input, function}; @@ -409,7 +413,8 @@ impl SlicePushDown { // [Pushdown] // these nodes will be pushed down. // State is None, we can continue - m @(Select {..}, None) + m @(Select {..}, None) | + m @ (SimpleProjection {..}, _) => { let (lp, state) = m; self.pushdown_and_continue(lp, state, lp_arena, expr_arena) @@ -427,14 +432,14 @@ impl SlicePushDown { } } (HStack {input, exprs, schema, options}, _) => { - let check = can_pushdown_slice_past_projections(&exprs, expr_arena); + let (can_pushdown, all_elementwise_and_any_expr_has_column) = can_pushdown_slice_past_projections(&exprs, expr_arena); if ( - // If the schema length is greater then an input column is being projected, so + // If the schema length is greater than an input column is being projected, so // the exprs in with_columns do not need to have an input column name. - schema.len() > exprs.len() && check.0 + schema.len() > exprs.len() && can_pushdown ) - || check.1 // e.g. select(c).with_columns(c = c + 1) + || all_elementwise_and_any_expr_has_column // e.g. select(c).with_columns(c = c + 1) { let lp = HStack {input, exprs, schema, options}; self.pushdown_and_continue(lp, state, lp_arena, expr_arena) diff --git a/crates/polars-plan/src/plans/options.rs b/crates/polars-plan/src/plans/options.rs index 3528387c06f0..078acbae7177 100644 --- a/crates/polars-plan/src/plans/options.rs +++ b/crates/polars-plan/src/plans/options.rs @@ -19,7 +19,8 @@ use polars_time::{DynamicGroupOptions, RollingGroupOptions}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -use crate::plans::ExprIR; +use crate::dsl::Selector; +use crate::plans::{ExprIR, PlSmallStr}; #[cfg(feature = "python")] use crate::prelude::python_udf::PythonFunction; @@ -30,14 +31,14 @@ pub type FileCount = u32; /// Generic options for all file types. pub struct FileScanOptions { pub slice: Option<(i64, usize)>, - pub with_columns: Option>, + pub with_columns: Option>, pub cache: bool, pub row_index: Option, pub rechunk: bool, pub file_counter: FileCount, pub hive_options: HiveOptions, pub glob: bool, - pub include_file_paths: Option>, + pub include_file_paths: Option, } #[derive(Clone, Debug, Copy, Default, Eq, PartialEq, Hash)] @@ -71,9 +72,23 @@ pub struct GroupbyOptions { #[derive(Clone, Debug, Eq, PartialEq, Default, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct DistinctOptions { +pub struct DistinctOptionsDSL { /// Subset of columns that will be taken into account. 
- pub subset: Option>>, + pub subset: Option>, + /// This will maintain the order of the input. + /// Note that this is more expensive. + /// `maintain_order` is not supported in the streaming + /// engine. + pub maintain_order: bool, + /// Which rows to keep. + pub keep_strategy: UniqueKeepStrategy, +} + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] +pub struct DistinctOptionsIR { + /// Subset of columns that will be taken into account. + pub subset: Option>, /// This will maintain the order of the input. /// Note that this is more expensive. /// `maintain_order` is not supported in the streaming @@ -199,6 +214,13 @@ impl FunctionOptions { pub fn check_lengths(&self) -> bool { self.check_lengths.0 } + + pub fn is_elementwise(&self) -> bool { + self.collect_groups == ApplyOptions::ElementWise + && !self + .flags + .contains(FunctionFlags::CHANGES_LENGTH | FunctionFlags::RETURNS_SCALAR) + } } impl Default for FunctionOptions { @@ -235,7 +257,7 @@ pub struct PythonOptions { /// Schema the reader will produce when the file is read. pub output_schema: Option, // Projected column names. - pub with_columns: Option>, + pub with_columns: Option>, // Which interface is the python function. pub python_source: PythonScanSource, /// Optional predicate the reader must apply. diff --git a/crates/polars-plan/src/plans/python/pyarrow.rs b/crates/polars-plan/src/plans/python/pyarrow.rs index 1232fcfde673..abd018b3a4e6 100644 --- a/crates/polars-plan/src/plans/python/pyarrow.rs +++ b/crates/polars-plan/src/plans/python/pyarrow.rs @@ -37,7 +37,7 @@ pub fn predicate_to_pa( None } }, - AExpr::Column(name) => Some(format!("pa.compute.field('{}')", name.as_ref())), + AExpr::Column(name) => Some(format!("pa.compute.field('{}')", name)), AExpr::Literal(LiteralValue::Series(s)) => { if !args.allow_literal_series || s.is_empty() || s.len() > 100 { None diff --git a/crates/polars-plan/src/plans/schema.rs b/crates/polars-plan/src/plans/schema.rs index 96e6c1b0d2c2..ac129afc3703 100644 --- a/crates/polars-plan/src/plans/schema.rs +++ b/crates/polars-plan/src/plans/schema.rs @@ -4,7 +4,7 @@ use std::sync::Mutex; use arrow::datatypes::ArrowSchemaRef; use either::Either; use polars_core::prelude::*; -use polars_utils::format_smartstring; +use polars_utils::format_pl_smallstr; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -18,7 +18,12 @@ impl DslPlan { pub fn compute_schema(&self) -> PolarsResult { let mut lp_arena = Default::default(); let mut expr_arena = Default::default(); - let node = to_alp(self.clone(), &mut expr_arena, &mut lp_arena, false, true)?; + let node = to_alp( + self.clone(), + &mut expr_arena, + &mut lp_arena, + &mut OptFlags::schema_only(), + )?; Ok(lp_arena.get(node).schema(&lp_arena).into_owned()) } @@ -56,7 +61,7 @@ impl FileInfo { for field in hive_schema.iter_fields() { if let Ok(existing) = schema.try_get_mut(&field.name) { - *existing = field.data_type().clone(); + *existing = field.dtype().clone(); } else { schema .insert_at_index(schema.len(), field.name, field.dtype.clone()) @@ -281,7 +286,7 @@ pub(crate) fn det_join_schema( { let left_is_removed = join_on_left.contains(name.as_str()) && should_coalesce; if schema_left.contains(name.as_str()) && !left_is_removed { - let new_name = format_smartstring!("{}{}", name, options.args.suffix()); + let new_name = format_pl_smallstr!("{}{}", name, options.args.suffix()); new_schema.with_column(new_name, dtype.clone()); } else { new_schema.with_column(name.clone(), 
dtype.clone()); @@ -314,7 +319,7 @@ pub(crate) fn det_join_schema( if should_coalesce && field_left.name != field_right.name { if schema_left.contains(&field_right.name) { new_schema.with_column( - _join_suffix_name(&field_right.name, options.args.suffix()).into(), + _join_suffix_name(&field_right.name, options.args.suffix()), field_right.dtype, ); } else { @@ -346,7 +351,7 @@ pub(crate) fn det_join_schema( // The names that are joined on are merged if schema_left.contains(name.as_str()) { - let new_name = format_smartstring!("{}{}", name, options.args.suffix()); + let new_name = format_pl_smallstr!("{}{}", name, options.args.suffix()); new_schema.with_column(new_name, dtype.clone()); } else { new_schema.with_column(name.clone(), dtype.clone()); diff --git a/crates/polars-plan/src/plans/visitor/expr.rs b/crates/polars-plan/src/plans/visitor/expr.rs index 39fdd659077d..2f5fce9bc283 100644 --- a/crates/polars-plan/src/plans/visitor/expr.rs +++ b/crates/polars-plan/src/plans/visitor/expr.rs @@ -53,7 +53,7 @@ impl TreeWalker for Expr { BinaryExpr { left, op, right } => { BinaryExpr { left: am(left, &mut f)? , op, right: am(right, f)?} }, - Cast { expr, data_type, options: strict } => Cast { expr: am(expr, f)?, data_type, options: strict }, + Cast { expr, dtype, options: strict } => Cast { expr: am(expr, f)?, dtype, options: strict }, Sort { expr, options } => Sort { expr: am(expr, f)?, options }, Gather { expr, idx, returns_scalar } => Gather { expr: am(expr, &mut f)?, idx: am(idx, f)?, returns_scalar }, SortBy { expr, by, sort_options } => SortBy { expr: am(expr, &mut f)?, by: by.into_iter().map(f).collect::>()?, sort_options }, @@ -166,17 +166,16 @@ impl AExpr { (Alias(_, l), Alias(_, r)) => l == r, (Column(l), Column(r)) => l == r, (Literal(l), Literal(r)) => l == r, - (Nth(l), Nth(r)) => l == r, (Window { options: l, .. }, Window { options: r, .. }) => l == r, ( Cast { options: strict_l, - data_type: dtl, + dtype: dtl, .. }, Cast { options: strict_r, - data_type: dtr, + dtype: dtr, .. }, ) => strict_l == strict_r && dtl == dtr, diff --git a/crates/polars-plan/src/plans/visitor/hash.rs b/crates/polars-plan/src/plans/visitor/hash.rs index 80c251108297..7087122802ea 100644 --- a/crates/polars-plan/src/plans/visitor/hash.rs +++ b/crates/polars-plan/src/plans/visitor/hash.rs @@ -74,7 +74,7 @@ impl Hash for HashableEqLP<'_> { predicate.traverse_and_hash(self.expr_arena, state); }, IR::Scan { - paths, + sources, file_info: _, hive_parts: _, predicate, @@ -84,7 +84,7 @@ impl Hash for HashableEqLP<'_> { } => { // We don't have to traverse the schema, hive partitions etc. as they are derivative from the paths. 
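det_join_schema above resolves output-name collisions by appending the join suffix to a right-hand column whenever the left schema already contains that name. Roughly (illustrative helper, not the actual signature):

// Decide the output name of a right-table column in the joined schema.
fn right_output_name(left_has_name: bool, name: &str, suffix: &str) -> String {
    if left_has_name {
        format!("{name}{suffix}") // e.g. "id" with suffix "_right" becomes "id_right"
    } else {
        name.to_string()
    }
}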
scan_type.hash(state); - paths.hash(state); + sources.hash(state); hash_option_expr(predicate, self.expr_arena, state); file_options.hash(state); }, @@ -254,7 +254,7 @@ impl HashableEqLP<'_> { ) => expr_ir_eq(l, r, self.expr_arena), ( IR::Scan { - paths: pl, + sources: pl, file_info: _, hive_parts: _, predicate: pred_l, @@ -263,7 +263,7 @@ impl HashableEqLP<'_> { file_options: ol, }, IR::Scan { - paths: pr, + sources: pr, file_info: _, hive_parts: _, predicate: pred_r, @@ -272,7 +272,7 @@ impl HashableEqLP<'_> { file_options: or, }, ) => { - pl == pr + pl.as_paths() == pr.as_paths() && stl == str && ol == or && opt_expr_ir_eq(pred_l, pred_r, self.expr_arena) diff --git a/crates/polars-plan/src/utils.rs b/crates/polars-plan/src/utils.rs index cd7e6c3e0c7e..bf18cc4119d2 100644 --- a/crates/polars-plan/src/utils.rs +++ b/crates/polars-plan/src/utils.rs @@ -3,16 +3,18 @@ use std::iter::FlatMap; use polars_core::prelude::*; use polars_utils::idx_vec::UnitVec; -use smartstring::alias::String as SmartString; -use crate::constants::{get_len_name, LEN}; +use crate::constants::get_len_name; use crate::prelude::*; /// Utility to write comma delimited strings -pub fn comma_delimited(mut s: String, items: &[SmartString]) -> String { +pub fn comma_delimited(mut s: String, items: &[S]) -> String +where + S: AsRef, +{ s.push('('); for c in items { - s.push_str(c); + s.push_str(c.as_ref()); s.push_str(", "); } s.pop(); @@ -135,7 +137,7 @@ pub fn has_null(current_expr: &Expr) -> bool { }) } -pub fn aexpr_output_name(node: Node, arena: &Arena) -> PolarsResult> { +pub fn aexpr_output_name(node: Node, arena: &Arena) -> PolarsResult { for (_, ae) in arena.iter(node) { match ae { // don't follow the partition by branch @@ -143,7 +145,7 @@ pub fn aexpr_output_name(node: Node, arena: &Arena) -> PolarsResult return Ok(name.clone()), AExpr::Alias(_, name) => return Ok(name.clone()), AExpr::Len => return Ok(get_len_name()), - AExpr::Literal(val) => return Ok(val.output_column_name()), + AExpr::Literal(val) => return Ok(val.output_column_name().clone()), _ => {}, } } @@ -155,7 +157,7 @@ pub fn aexpr_output_name(node: Node, arena: &Arena) -> PolarsResult PolarsResult> { +pub fn expr_output_name(expr: &Expr) -> PolarsResult { for e in expr { match e { // don't follow the partition by branch @@ -171,7 +173,7 @@ pub fn expr_output_name(expr: &Expr) -> PolarsResult> { "this expression may produce multiple output names" ), Expr::Len => return Ok(get_len_name()), - Expr::Literal(val) => return Ok(val.output_column_name()), + Expr::Literal(val) => return Ok(val.output_column_name().clone()), _ => {}, } } @@ -183,7 +185,7 @@ pub fn expr_output_name(expr: &Expr) -> PolarsResult> { /// This function should be used to find the name of the start of an expression /// Normal iteration would just return the first root column it found -pub(crate) fn get_single_leaf(expr: &Expr) -> PolarsResult> { +pub(crate) fn get_single_leaf(expr: &Expr) -> PolarsResult { for e in expr { match e { Expr::Filter { input, .. } => return get_single_leaf(input), @@ -191,7 +193,7 @@ pub(crate) fn get_single_leaf(expr: &Expr) -> PolarsResult> { Expr::SortBy { expr, .. } => return get_single_leaf(expr), Expr::Window { function, .. 
} => return get_single_leaf(function), Expr::Column(name) => return Ok(name.clone()), - Expr::Len => return Ok(ColumnName::from(LEN)), + Expr::Len => return Ok(get_len_name()), _ => {}, } } @@ -201,17 +203,17 @@ pub(crate) fn get_single_leaf(expr: &Expr) -> PolarsResult> { } #[allow(clippy::type_complexity)] -pub fn expr_to_leaf_column_names_iter(expr: &Expr) -> impl Iterator> + '_ { +pub fn expr_to_leaf_column_names_iter(expr: &Expr) -> impl Iterator + '_ { expr_to_leaf_column_exprs_iter(expr).flat_map(|e| expr_to_leaf_column_name(e).ok()) } /// This should gradually replace expr_to_root_column as this will get all names in the tree. -pub fn expr_to_leaf_column_names(expr: &Expr) -> Vec> { +pub fn expr_to_leaf_column_names(expr: &Expr) -> Vec { expr_to_leaf_column_names_iter(expr).collect() } /// unpack alias(col) to name of the root column name -pub fn expr_to_leaf_column_name(expr: &Expr) -> PolarsResult> { +pub fn expr_to_leaf_column_name(expr: &Expr) -> PolarsResult { let mut leaves = expr_to_leaf_column_exprs_iter(expr).collect::>(); polars_ensure!(leaves.len() <= 1, ComputeError: "found more than one root column name"); match leaves.pop() { @@ -240,7 +242,7 @@ pub(crate) fn aexpr_to_column_nodes_iter<'a>( }) } -pub fn column_node_to_name(node: ColumnNode, arena: &Arena) -> &Arc { +pub fn column_node_to_name(node: ColumnNode, arena: &Arena) -> &PlSmallStr { if let AExpr::Column(name) = arena.get(node.0) { name } else { @@ -254,7 +256,7 @@ pub(crate) fn rename_matching_aexpr_leaf_names( node: Node, arena: &mut Arena, current: &str, - new_name: &str, + new_name: PlSmallStr, ) -> Node { let mut leaves = aexpr_to_column_nodes_iter(node, arena); @@ -262,10 +264,10 @@ pub(crate) fn rename_matching_aexpr_leaf_names( // we convert to expression as we cannot easily copy the aexpr. 
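rename_matching_aexpr_leaf_names here rewrites every matching leaf column reference in an expression tree, which it does by round-tripping through Expr and mapping over it. The same walk-and-replace idea on a toy expression type (not the real AExpr, which lives in an arena of node ids):

#[derive(Clone, Debug, PartialEq)]
enum ToyExpr {
    Column(String),
    Add(Box<ToyExpr>, Box<ToyExpr>),
}

// Replace every leaf column named `current` with `new_name`, leaving the rest intact.
fn rename_leaf(e: ToyExpr, current: &str, new_name: &str) -> ToyExpr {
    match e {
        ToyExpr::Column(name) if name == current => ToyExpr::Column(new_name.to_string()),
        ToyExpr::Column(name) => ToyExpr::Column(name),
        ToyExpr::Add(l, r) => ToyExpr::Add(
            Box::new(rename_leaf(*l, current, new_name)),
            Box::new(rename_leaf(*r, current, new_name)),
        ),
    }
}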
let mut new_expr = node_to_expr(node, arena); new_expr = new_expr.map_expr(|e| match e { - Expr::Column(name) if &*name == current => Expr::Column(ColumnName::from(new_name)), + Expr::Column(name) if &*name == current => Expr::Column(new_name.clone()), e => e, }); - to_aexpr(new_expr, arena) + to_aexpr(new_expr, arena).expect("infallible") } else { node } @@ -294,18 +296,18 @@ pub fn expressions_to_schema( pub fn aexpr_to_leaf_names_iter( node: Node, arena: &Arena, -) -> impl Iterator> + '_ { +) -> impl Iterator + '_ { aexpr_to_column_nodes_iter(node, arena).map(|node| match arena.get(node.0) { AExpr::Column(name) => name.clone(), _ => unreachable!(), }) } -pub fn aexpr_to_leaf_names(node: Node, arena: &Arena) -> Vec> { +pub fn aexpr_to_leaf_names(node: Node, arena: &Arena) -> Vec { aexpr_to_leaf_names_iter(node, arena).collect() } -pub fn aexpr_to_leaf_name(node: Node, arena: &Arena) -> Arc { +pub fn aexpr_to_leaf_name(node: Node, arena: &Arena) -> PlSmallStr { aexpr_to_leaf_names_iter(node, arena).next().unwrap() } @@ -358,7 +360,7 @@ pub(crate) fn expr_irs_to_schema, K: AsRef>( let mut field = arena.get(e.node()).to_field(schema, ctxt, arena).unwrap(); if let Some(name) = e.get_alias() { - field.name = name.as_ref().into() + field.name = name.clone() } field }) diff --git a/crates/polars-python/Cargo.toml b/crates/polars-python/Cargo.toml new file mode 100644 index 000000000000..b93d34a678e5 --- /dev/null +++ b/crates/polars-python/Cargo.toml @@ -0,0 +1,258 @@ +[package] +name = "polars-python" +version = { workspace = true } +authors = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +repository = { workspace = true } +description = "Enable running Polars workloads in Python" + +[dependencies] +polars-core = { workspace = true, features = ["python"] } +polars-error = { workspace = true } +polars-io = { workspace = true } +polars-lazy = { workspace = true, features = ["python"] } +polars-ops = { workspace = true } +polars-parquet = { workspace = true, optional = true } +polars-plan = { workspace = true } +polars-time = { workspace = true } +polars-utils = { workspace = true } + +# TODO! remove this once truly activated. 
This is required to make sdist building work +polars-stream = { workspace = true } + +ahash = { workspace = true } +arboard = { workspace = true, optional = true } +bytemuck = { workspace = true } +bytes = { workspace = true } +ciborium = { workspace = true } +either = { workspace = true } +itoa = { workspace = true } +libc = { workspace = true } +ndarray = { workspace = true } +num-traits = { workspace = true } +# TODO: Pin to released version once NumPy 2.0 support is merged +# https://github.com/PyO3/rust-numpy/issues/409 +numpy = { git = "https://github.com/stinodego/rust-numpy.git", rev = "9ba9962ae57ba26e35babdce6f179edf5fe5b9c8", default-features = false } +once_cell = { workspace = true } +pyo3 = { workspace = true, features = ["abi3-py38", "chrono", "multiple-pymethods"] } +recursive = { workspace = true } +serde_json = { workspace = true, optional = true } +thiserror = { workspace = true } + +[dependencies.polars] +workspace = true +features = [ + "abs", + "approx_unique", + "array_any_all", + "arg_where", + "business", + "concat_str", + "cum_agg", + "cumulative_eval", + "dataframe_arithmetic", + "month_start", + "month_end", + "offset_by", + "diagonal_concat", + "diff", + "dot_diagram", + "dot_product", + "dtype-categorical", + "dtype-full", + "dynamic_group_by", + "ewma", + "ewma_by", + "fmt", + "fused", + "interpolate", + "interpolate_by", + "is_first_distinct", + "is_last_distinct", + "is_unique", + "is_between", + "lazy", + "list_eval", + "list_to_struct", + "array_to_struct", + "log", + "mode", + "moment", + "ndarray", + "partition_by", + "product", + "random", + "range", + "rank", + "reinterpret", + "replace", + "rolling_window", + "rolling_window_by", + "round_series", + "row_hash", + "rows", + "semi_anti_join", + "serde-lazy", + "string_encoding", + "string_reverse", + "string_to_integer", + "string_pad", + "strings", + "temporal", + "to_dummies", + "true_div", + "unique_counts", + "zip_with", + "cov", +] + +[build-dependencies] +version_check = { workspace = true } + +[features] +# Features below are only there to enable building a slim binary during development. 
+avro = ["polars/avro"] +parquet = ["polars/parquet", "polars-parquet"] +ipc = ["polars/ipc"] +ipc_streaming = ["polars/ipc_streaming"] +is_in = ["polars/is_in"] +json = ["polars/serde", "serde_json", "polars/json", "polars-utils/serde"] +trigonometry = ["polars/trigonometry"] +sign = ["polars/sign"] +asof_join = ["polars/asof_join"] +cross_join = ["polars/cross_join"] +pct_change = ["polars/pct_change"] +repeat_by = ["polars/repeat_by"] + +streaming = ["polars/streaming"] +meta = ["polars/meta"] +search_sorted = ["polars/search_sorted"] +decompress = ["polars/decompress-fast"] +regex = ["polars/regex"] +csv = ["polars/csv"] +clipboard = ["arboard"] +extract_jsonpath = ["polars/extract_jsonpath"] +pivot = ["polars/pivot"] +top_k = ["polars/top_k"] +propagate_nans = ["polars/propagate_nans"] +sql = ["polars/sql"] +performant = ["polars/performant"] +timezones = ["polars/timezones"] +cse = ["polars/cse"] +merge_sorted = ["polars/merge_sorted"] +list_gather = ["polars/list_gather"] +list_count = ["polars/list_count"] +array_count = ["polars/array_count", "polars/dtype-array"] +binary_encoding = ["polars/binary_encoding"] +list_sets = ["polars-lazy/list_sets"] +list_any_all = ["polars/list_any_all"] +array_any_all = ["polars/array_any_all", "polars/dtype-array"] +list_drop_nulls = ["polars/list_drop_nulls"] +list_sample = ["polars/list_sample"] +cutqcut = ["polars/cutqcut"] +rle = ["polars/rle"] +extract_groups = ["polars/extract_groups"] +ffi_plugin = ["polars-plan/ffi_plugin"] +cloud = ["polars/cloud", "polars/aws", "polars/gcp", "polars/azure", "polars/http"] +peaks = ["polars/peaks"] +hist = ["polars/hist"] +find_many = ["polars/find_many"] +new_streaming = ["polars-lazy/new_streaming"] + +dtype-i8 = [] +dtype-i16 = [] +dtype-u8 = [] +dtype-u16 = [] +dtype-array = [] +object = ["polars/object"] + +dtypes = [ + "dtype-array", + "dtype-i16", + "dtype-i8", + "dtype-u16", + "dtype-u8", + "object", +] + +operations = [ + "array_any_all", + "array_count", + "is_in", + "repeat_by", + "trigonometry", + "sign", + "performant", + "list_gather", + "list_count", + "list_sets", + "list_any_all", + "list_drop_nulls", + "list_sample", + "cutqcut", + "rle", + "extract_groups", + "pivot", + "extract_jsonpath", + "asof_join", + "cross_join", + "pct_change", + "search_sorted", + "merge_sorted", + "top_k", + "propagate_nans", + "timezones", + "peaks", + "hist", + "find_many", +] + +io = [ + "json", + "parquet", + "ipc", + "ipc_streaming", + "avro", + "csv", + "cloud", + "clipboard", +] + +optimizations = [ + "cse", + "polars/fused", + "streaming", +] + +polars_cloud = ["polars/polars_cloud"] + +# also includes simd +nightly = ["polars/nightly"] + +pymethods = [] + +all = [ + "pymethods", + "optimizations", + "io", + "operations", + "dtypes", + "meta", + "decompress", + "regex", + "sql", + "binary_encoding", + "ffi_plugin", + "polars_cloud", + # "new_streaming", +] + +# we cannot conditionally activate simd +# https://github.com/rust-lang/cargo/issues/1197 +# so we have an indirection and compile +# with --no-default-features --features=all for targets without simd +default = [ + "all", +] diff --git a/crates/polars-python/LICENSE b/crates/polars-python/LICENSE new file mode 120000 index 000000000000..30cff7403da0 --- /dev/null +++ b/crates/polars-python/LICENSE @@ -0,0 +1 @@ +../../LICENSE \ No newline at end of file diff --git a/crates/polars-python/README.md b/crates/polars-python/README.md new file mode 100644 index 000000000000..3a68700e34fc --- /dev/null +++ b/crates/polars-python/README.md @@ -0,0 +1,6 
@@ +# polars-python + +`polars-python` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library. +It enables running Polars workloads in Python. + +**Important Note**: This crate is **not intended for external usage**. Please refer to the main [Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-python/build.rs b/crates/polars-python/build.rs new file mode 100644 index 000000000000..3e4ab64620ac --- /dev/null +++ b/crates/polars-python/build.rs @@ -0,0 +1,7 @@ +fn main() { + println!("cargo:rerun-if-changed=build.rs"); + let channel = version_check::Channel::read().unwrap(); + if channel.is_nightly() { + println!("cargo:rustc-cfg=feature=\"nightly\""); + } +} diff --git a/py-polars/src/batched_csv.rs b/crates/polars-python/src/batched_csv.rs similarity index 96% rename from py-polars/src/batched_csv.rs rename to crates/polars-python/src/batched_csv.rs index 2f5159bc8402..1a688ba8ba1a 100644 --- a/py-polars/src/batched_csv.rs +++ b/crates/polars-python/src/batched_csv.rs @@ -61,7 +61,7 @@ impl PyBatchedCsv { let null_values = null_values.map(|w| w.0); let eol_char = eol_char.as_bytes()[0]; let row_index = row_index.map(|(name, offset)| RowIndex { - name: Arc::from(name.as_str()), + name: name.into(), offset, }); let quote_char = if let Some(s) = quote_char { @@ -79,7 +79,7 @@ impl PyBatchedCsv { .iter() .map(|(name, dtype)| { let dtype = dtype.0.clone(); - Field::new(name, dtype) + Field::new((&**name).into(), dtype) }) .collect::() }); @@ -102,7 +102,7 @@ impl PyBatchedCsv { .with_projection(projection.map(Arc::new)) .with_rechunk(rechunk) .with_chunk_size(chunk_size) - .with_columns(columns.map(Arc::from)) + .with_columns(columns.map(|x| x.into_iter().map(PlSmallStr::from_string).collect())) .with_n_threads(n_threads) .with_dtype_overwrite(overwrite_dtype_slice.map(Arc::new)) .with_low_memory(low_memory) diff --git a/py-polars/src/cloud.rs b/crates/polars-python/src/cloud.rs similarity index 53% rename from py-polars/src/cloud.rs rename to crates/polars-python/src/cloud.rs index 5c8a7d01eafe..dacca675c551 100644 --- a/py-polars/src/cloud.rs +++ b/crates/polars-python/src/cloud.rs @@ -5,9 +5,9 @@ use crate::error::PyPolarsErr; use crate::PyLazyFrame; #[pyfunction] -pub fn prepare_cloud_plan(lf: PyLazyFrame, uri: String, py: Python) -> PyResult { +pub fn prepare_cloud_plan(lf: PyLazyFrame, py: Python) -> PyResult { let plan = lf.ldf.logical_plan; - let bytes = polars::prelude::prepare_cloud_plan(plan, uri).map_err(PyPolarsErr::from)?; + let bytes = polars::prelude::prepare_cloud_plan(plan).map_err(PyPolarsErr::from)?; Ok(PyBytes::new_bound(py, &bytes).to_object(py)) } diff --git a/py-polars/src/conversion/any_value.rs b/crates/polars-python/src/conversion/any_value.rs similarity index 98% rename from py-polars/src/conversion/any_value.rs rename to crates/polars-python/src/conversion/any_value.rs index 088d2e430f99..3141d02799fb 100644 --- a/py-polars/src/conversion/any_value.rs +++ b/crates/polars-python/src/conversion/any_value.rs @@ -5,7 +5,7 @@ use polars::chunked_array::object::PolarsObjectSafe; #[cfg(feature = "object")] use polars::datatypes::OwnedObject; use polars::datatypes::{DataType, Field, PlHashMap, TimeUnit}; -use polars::prelude::{AnyValue, Series}; +use polars::prelude::{AnyValue, PlSmallStr, Series}; use polars_core::export::chrono::{NaiveDate, NaiveDateTime, NaiveTime, TimeDelta, Timelike}; use polars_core::utils::any_values_to_supertype_and_n_dtypes; use 
polars_core::utils::arrow::temporal_conversions::date32_to_date; @@ -289,7 +289,10 @@ pub(crate) fn py_object_to_any_value<'py>( } if ob.is_empty()? { - Ok(AnyValue::List(Series::new_empty("", &DataType::Null))) + Ok(AnyValue::List(Series::new_empty( + PlSmallStr::EMPTY, + &DataType::Null, + ))) } else if ob.is_instance_of::() | ob.is_instance_of::() { const INFER_SCHEMA_LENGTH: usize = 25; @@ -320,7 +323,7 @@ pub(crate) fn py_object_to_any_value<'py>( avs.push(av) } - let s = Series::from_any_values_and_dtype("", &avs, &dtype, strict) + let s = Series::from_any_values_and_dtype(PlSmallStr::EMPTY, &avs, &dtype, strict) .map_err(|e| { PyTypeError::new_err(format!( "{e}\n\nHint: Try setting `strict=False` to allow passing data with mixed types." @@ -348,7 +351,7 @@ pub(crate) fn py_object_to_any_value<'py>( let key = k.extract::>()?; let val = py_object_to_any_value(&v, strict)?; let dtype = val.dtype(); - keys.push(Field::new(&key, dtype)); + keys.push(Field::new(key.as_ref().into(), dtype)); vals.push(val) } Ok(AnyValue::StructOwned(Box::new((vals, keys)))) diff --git a/py-polars/src/conversion/chunked_array.rs b/crates/polars-python/src/conversion/chunked_array.rs similarity index 98% rename from py-polars/src/conversion/chunked_array.rs rename to crates/polars-python/src/conversion/chunked_array.rs index abeb4fa728e8..3a69d61f7dd1 100644 --- a/py-polars/src/conversion/chunked_array.rs +++ b/crates/polars-python/src/conversion/chunked_array.rs @@ -64,7 +64,7 @@ impl ToPyObject for Wrap<&DatetimeChunked> { let utils = UTILS.bind(py); let convert = utils.getattr(intern!(py, "to_py_datetime")).unwrap(); let time_unit = self.0.time_unit().to_ascii(); - let time_zone = time_zone.to_object(py); + let time_zone = time_zone.as_deref().to_object(py); let iter = self .0 .iter() diff --git a/py-polars/src/conversion/datetime.rs b/crates/polars-python/src/conversion/datetime.rs similarity index 100% rename from py-polars/src/conversion/datetime.rs rename to crates/polars-python/src/conversion/datetime.rs diff --git a/py-polars/src/conversion/mod.rs b/crates/polars-python/src/conversion/mod.rs similarity index 91% rename from py-polars/src/conversion/mod.rs rename to crates/polars-python/src/conversion/mod.rs index d6283597267a..fd8e97cb7adc 100644 --- a/py-polars/src/conversion/mod.rs +++ b/crates/polars-python/src/conversion/mod.rs @@ -2,7 +2,9 @@ pub(crate) mod any_value; pub(crate) mod chunked_array; mod datetime; use std::fmt::{Display, Formatter}; +use std::fs::File; use std::hash::{Hash, Hasher}; +use std::path::PathBuf; #[cfg(feature = "object")] use polars::chunked_array::object::PolarsObjectSafe; @@ -19,6 +21,8 @@ use polars_core::utils::materialize_dyn_int; use polars_lazy::prelude::*; #[cfg(feature = "parquet")] use polars_parquet::write::StatisticsOptions; +use polars_plan::plans::ScanSources; +use polars_utils::pl_str::PlSmallStr; use polars_utils::total_ord::{TotalEq, TotalHash}; use pyo3::basic::CompareOp; use pyo3::exceptions::{PyTypeError, PyValueError}; @@ -26,9 +30,9 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; use pyo3::types::{PyDict, PyList, PySequence}; -use smartstring::alias::String as SmartString; use crate::error::PyPolarsErr; +use crate::file::{get_python_scan_source_input, PythonScanSourceInput}; #[cfg(feature = "object")] use crate::object::OBJECT_NAME; use crate::prelude::*; @@ -110,15 +114,27 @@ pub(crate) fn to_series(py: Python, s: PySeries) -> PyObject { constructor.call1((s,)).unwrap().into_py(py) } +impl<'a> FromPyObject<'a> for Wrap 
{ + fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult { + Ok(Wrap((&*ob.extract::()?).into())) + } +} + #[cfg(feature = "csv")] impl<'a> FromPyObject<'a> for Wrap { fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult { - if let Ok(s) = ob.extract::() { - Ok(Wrap(NullValues::AllColumnsSingle(s))) - } else if let Ok(s) = ob.extract::>() { - Ok(Wrap(NullValues::AllColumns(s))) - } else if let Ok(s) = ob.extract::>() { - Ok(Wrap(NullValues::Named(s))) + if let Ok(s) = ob.extract::() { + Ok(Wrap(NullValues::AllColumnsSingle((&*s).into()))) + } else if let Ok(s) = ob.extract::>() { + Ok(Wrap(NullValues::AllColumns( + s.into_iter().map(|x| (&*x).into()).collect(), + ))) + } else if let Ok(s) = ob.extract::>() { + Ok(Wrap(NullValues::Named( + s.into_iter() + .map(|(a, b)| ((&*a).into(), (&*b).into())) + .collect(), + ))) } else { Err( PyPolarsErr::Other("could not extract value from null_values argument".into()) @@ -243,7 +259,7 @@ impl ToPyObject for Wrap { DataType::Datetime(tu, tz) => { let datetime_class = pl.getattr(intern!(py, "Datetime")).unwrap(); datetime_class - .call1((tu.to_ascii(), tz.clone())) + .call1((tu.to_ascii(), tz.as_deref())) .unwrap() .into() }, @@ -267,7 +283,9 @@ impl ToPyObject for Wrap { // we should always have an initialized rev_map coming from rust let categories = rev_map.as_ref().unwrap().get_categories(); let class = pl.getattr(intern!(py, "Enum")).unwrap(); - let s = Series::from_arrow("category", categories.to_boxed()).unwrap(); + let s = + Series::from_arrow(PlSmallStr::from_static("category"), categories.to_boxed()) + .unwrap(); let series = to_series(py, s.into()); return class.call1((series,)).unwrap().into(); }, @@ -276,7 +294,7 @@ impl ToPyObject for Wrap { let field_class = pl.getattr(intern!(py, "Field")).unwrap(); let iter = fields.iter().map(|fld| { let name = fld.name().as_str(); - let dtype = Wrap(fld.data_type().clone()).to_object(py); + let dtype = Wrap(fld.dtype().clone()).to_object(py); field_class.call1((name, dtype)).unwrap() }); let fields = PyList::new_bound(py, iter); @@ -311,7 +329,7 @@ impl<'py> FromPyObject<'py> for Wrap { let dtype = ob .getattr(intern!(py, "dtype"))? 
.extract::>()?; - Ok(Wrap(Field::new(&name, dtype.0))) + Ok(Wrap(Field::new((&*name).into(), dtype.0))) } } @@ -385,7 +403,7 @@ impl<'py> FromPyObject<'py> for Wrap { let s = get_series(&categories.as_borrowed())?; let ca = s.str().map_err(PyPolarsErr::from)?; let categories = ca.downcast_iter().next().unwrap().clone(); - create_enum_data_type(categories) + create_enum_dtype(categories) }, "Date" => DataType::Date, "Time" => DataType::Time, @@ -393,8 +411,8 @@ impl<'py> FromPyObject<'py> for Wrap { let time_unit = ob.getattr(intern!(py, "time_unit")).unwrap(); let time_unit = time_unit.extract::>()?.0; let time_zone = ob.getattr(intern!(py, "time_zone")).unwrap(); - let time_zone = time_zone.extract()?; - DataType::Datetime(time_unit, time_zone) + let time_zone = time_zone.extract::>()?; + DataType::Datetime(time_unit, time_zone.as_deref().map(|x| x.into())) }, "Duration" => { let time_unit = ob.getattr(intern!(py, "time_unit")).unwrap(); @@ -507,13 +525,75 @@ impl<'py> FromPyObject<'py> for Wrap { let key = key.extract::()?; let val = val.extract::>()?; - Ok(Field::new(&key, val.0)) + Ok(Field::new((&*key).into(), val.0)) }) .collect::>()?, )) } } +impl<'py> FromPyObject<'py> for Wrap { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let list = ob.downcast::()?.to_owned(); + + if list.is_empty() { + return Ok(Wrap(ScanSources::default())); + } + + enum MutableSources { + Paths(Vec), + Files(Vec), + Buffers(Vec), + } + + let num_items = list.len(); + let mut iter = list + .into_iter() + .map(|val| get_python_scan_source_input(val.unbind(), false)); + + let Some(first) = iter.next() else { + return Ok(Wrap(ScanSources::default())); + }; + + let mut sources = match first? { + PythonScanSourceInput::Path(path) => { + let mut sources = Vec::with_capacity(num_items); + sources.push(path); + MutableSources::Paths(sources) + }, + PythonScanSourceInput::File(file) => { + let mut sources = Vec::with_capacity(num_items); + sources.push(file); + MutableSources::Files(sources) + }, + PythonScanSourceInput::Buffer(buffer) => { + let mut sources = Vec::with_capacity(num_items); + sources.push(buffer); + MutableSources::Buffers(sources) + }, + }; + + for source in iter { + match (&mut sources, source?) 
{ + (MutableSources::Paths(v), PythonScanSourceInput::Path(p)) => v.push(p), + (MutableSources::Files(v), PythonScanSourceInput::File(f)) => v.push(f), + (MutableSources::Buffers(v), PythonScanSourceInput::Buffer(f)) => v.push(f), + _ => { + return Err(PyTypeError::new_err( + "Cannot combine in-memory bytes, paths and files for scan sources", + )) + }, + } + } + + Ok(Wrap(match sources { + MutableSources::Paths(i) => ScanSources::Paths(i.into()), + MutableSources::Files(i) => ScanSources::Files(i.into()), + MutableSources::Buffers(i) => ScanSources::Buffers(i.into()), + })) + } +} + impl IntoPy for Wrap<&Schema> { fn into_py(self, py: Python<'_>) -> PyObject { let dict = PyDict::new_bound(py); @@ -1173,12 +1253,15 @@ pub(crate) fn parse_parquet_compression( Ok(parsed) } -pub(crate) fn strings_to_smartstrings(container: I) -> Vec +pub(crate) fn strings_to_pl_smallstr(container: I) -> Vec where I: IntoIterator, S: AsRef, { - container.into_iter().map(|s| s.as_ref().into()).collect() + container + .into_iter() + .map(|s| PlSmallStr::from_str(s.as_ref())) + .collect() } #[derive(Debug, Copy, Clone)] diff --git a/py-polars/src/dataframe/construction.rs b/crates/polars-python/src/dataframe/construction.rs similarity index 94% rename from py-polars/src/dataframe/construction.rs rename to crates/polars-python/src/dataframe/construction.rs index d01dcf24ad34..229e1d85b2bb 100644 --- a/py-polars/src/dataframe/construction.rs +++ b/crates/polars-python/src/dataframe/construction.rs @@ -1,9 +1,12 @@ use polars::frame::row::{rows_to_schema_supertypes, rows_to_supertypes, Row}; +use polars::prelude::*; use pyo3::prelude::*; +use pyo3::types::PyDict; -use super::*; +use super::PyDataFrame; use crate::conversion::any_value::py_object_to_any_value; use crate::conversion::{vec_extract_wrapped, Wrap}; +use crate::error::PyPolarsErr; use crate::interop; #[pymethods] @@ -82,7 +85,7 @@ fn update_schema_from_rows( rows: &[Row], infer_schema_length: Option, ) -> PyResult<()> { - let schema_is_complete = schema.iter_dtypes().all(|dtype| dtype.is_known()); + let schema_is_complete = schema.iter_values().all(|dtype| dtype.is_known()); if schema_is_complete { return Ok(()); } @@ -92,7 +95,7 @@ fn update_schema_from_rows( rows_to_supertypes(rows, infer_schema_length).map_err(PyPolarsErr::from)?; let inferred_dtypes_slice = inferred_dtypes.as_slice(); - for (i, dtype) in schema.iter_dtypes_mut().enumerate() { + for (i, dtype) in schema.iter_values_mut().enumerate() { if !dtype.is_known() { *dtype = inferred_dtypes_slice.get(i).ok_or_else(|| { polars_err!(SchemaMismatch: "the number of columns in the schema does not match the data") @@ -117,7 +120,7 @@ fn resolve_schema_overrides(schema: &mut Schema, schema_overrides: Option Vec<&str> { - self.df.get_column_names() + self.df.get_column_names_str() } /// set column names pub fn set_column_names(&mut self, names: Vec) -> PyResult<()> { self.df - .set_column_names(&names) + .set_column_names(names.iter().map(|x| &**x)) .map_err(PyPolarsErr::from)?; Ok(()) } @@ -243,13 +246,16 @@ impl PyDataFrame { } pub fn select(&self, columns: Vec) -> PyResult { - let df = self.df.select(columns).map_err(PyPolarsErr::from)?; + let df = self + .df + .select(columns.iter().map(|x| &**x)) + .map_err(PyPolarsErr::from)?; Ok(PyDataFrame::new(df)) } pub fn gather(&self, indices: Wrap>) -> PyResult { let indices = indices.0; - let indices = IdxCa::from_vec("", indices); + let indices = IdxCa::from_vec("".into(), indices); let df = self.df.take(&indices).map_err(PyPolarsErr::from)?; 
Ok(PyDataFrame::new(df)) } @@ -319,7 +325,7 @@ impl PyDataFrame { pub fn with_row_index(&self, name: &str, offset: Option) -> PyResult { let df = self .df - .with_row_index(name, offset) + .with_row_index(name.into(), offset) .map_err(PyPolarsErr::from)?; Ok(df.into()) } @@ -331,9 +337,9 @@ impl PyDataFrame { maintain_order: bool, ) -> PyResult { let gb = if maintain_order { - self.df.group_by_stable(&by) + self.df.group_by_stable(by.iter().map(|x| &**x)) } else { - self.df.group_by(&by) + self.df.group_by(by.iter().map(|x| &**x)) } .map_err(PyPolarsErr::from)?; @@ -366,10 +372,12 @@ impl PyDataFrame { Ok(df.into()) } + #[allow(clippy::should_implement_trait)] pub fn clone(&self) -> Self { PyDataFrame::new(self.df.clone()) } + #[cfg(feature = "pivot")] pub fn unpivot( &self, on: Vec, @@ -377,12 +385,12 @@ impl PyDataFrame { value_name: Option<&str>, variable_name: Option<&str>, ) -> PyResult { - let args = UnpivotArgs { - on: strings_to_smartstrings(on), - index: strings_to_smartstrings(index), + use polars_ops::pivot::UnpivotDF; + let args = UnpivotArgsIR { + on: strings_to_pl_smallstr(on), + index: strings_to_pl_smallstr(index), value_name: value_name.map(|s| s.into()), variable_name: variable_name.map(|s| s.into()), - streamable: false, }; let df = self.df.unpivot2(args).map_err(PyPolarsErr::from)?; @@ -534,7 +542,7 @@ impl PyDataFrame { } pub fn hash_rows(&mut self, k0: u64, k1: u64, k2: u64, k3: u64) -> PyResult { - let hb = ahash::RandomState::with_seeds(k0, k1, k2, k3); + let hb = PlRandomState::with_seeds(k0, k1, k2, k3); let hash = self.df.hash_rows(Some(hb)).map_err(PyPolarsErr::from)?; Ok(hash.into_series().into()) } @@ -576,7 +584,7 @@ impl PyDataFrame { } pub fn to_struct(&self, name: &str, invalid_indices: Vec) -> PySeries { - let ca = self.df.clone().into_struct(name); + let ca = self.df.clone().into_struct(name.into()); if !invalid_indices.is_empty() { let mut validity = MutableBitmap::with_capacity(ca.len()); @@ -608,7 +616,10 @@ impl PyDataFrame { // underneath of you, so don't use this anywhere else. 
let mut df = std::mem::take(&mut self.df); let cols = unsafe { std::mem::take(df.get_columns_mut()) }; - let (ptr, len, cap) = cols.into_raw_parts(); + let mut md_cols = ManuallyDrop::new(cols); + let ptr = md_cols.as_mut_ptr(); + let len = md_cols.len(); + let cap = md_cols.capacity(); (ptr as usize, len, cap) } } diff --git a/py-polars/src/dataframe/io.rs b/crates/polars-python/src/dataframe/io.rs similarity index 97% rename from py-polars/src/dataframe/io.rs rename to crates/polars-python/src/dataframe/io.rs index e258abe036fa..dbdf91ddff09 100644 --- a/py-polars/src/dataframe/io.rs +++ b/crates/polars-python/src/dataframe/io.rs @@ -1,19 +1,22 @@ use std::io::BufWriter; use std::num::NonZeroUsize; +use std::sync::Arc; #[cfg(feature = "avro")] use polars::io::avro::AvroCompression; -use polars::io::mmap::ensure_not_mapped; use polars::io::RowIndex; +use polars::prelude::*; #[cfg(feature = "parquet")] use polars_parquet::arrow::write::StatisticsOptions; +use polars_utils::mmap::ensure_not_mapped; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; -use super::*; +use super::PyDataFrame; #[cfg(feature = "parquet")] use crate::conversion::parse_parquet_compression; use crate::conversion::Wrap; +use crate::error::PyPolarsErr; use crate::file::{ get_either_file, get_file_like, get_mmap_bytes_reader, get_mmap_bytes_reader_and_path, read_if_bytesio, EitherRustPythonFile, @@ -67,7 +70,7 @@ impl PyDataFrame { let null_values = null_values.map(|w| w.0); let eol_char = eol_char.as_bytes()[0]; let row_index = row_index.map(|(name, offset)| RowIndex { - name: Arc::from(name.as_str()), + name: name.into(), offset, }); let quote_char = quote_char.and_then(|s| s.as_bytes().first().copied()); @@ -77,7 +80,7 @@ impl PyDataFrame { .iter() .map(|(name, dtype)| { let dtype = dtype.0.clone(); - Field::new(name, dtype) + Field::new((&**name).into(), dtype) }) .collect::() }); @@ -102,7 +105,7 @@ impl PyDataFrame { .with_projection(projection.map(Arc::new)) .with_rechunk(rechunk) .with_chunk_size(chunk_size) - .with_columns(columns.map(Arc::from)) + .with_columns(columns.map(|x| x.into_iter().map(|x| x.into()).collect())) .with_n_threads(n_threads) .with_schema_overwrite(overwrite_dtype.map(Arc::new)) .with_dtype_overwrite(overwrite_dtype_slice.map(Arc::new)) @@ -150,7 +153,7 @@ impl PyDataFrame { use EitherRustPythonFile::*; let row_index = row_index.map(|(name, offset)| RowIndex { - name: Arc::from(name.as_str()), + name: name.into(), offset, }); let result = match get_either_file(py_f, false)? 
{ @@ -260,7 +263,7 @@ impl PyDataFrame { memory_map: bool, ) -> PyResult { let row_index = row_index.map(|(name, offset)| RowIndex { - name: Arc::from(name.as_str()), + name: name.into(), offset, }); py_f = read_if_bytesio(py_f); @@ -293,7 +296,7 @@ impl PyDataFrame { rechunk: bool, ) -> PyResult { let row_index = row_index.map(|(name, offset)| RowIndex { - name: Arc::from(name.as_str()), + name: name.into(), offset, }); py_f = read_if_bytesio(py_f); diff --git a/py-polars/src/dataframe/mod.rs b/crates/polars-python/src/dataframe/mod.rs similarity index 65% rename from py-polars/src/dataframe/mod.rs rename to crates/polars-python/src/dataframe/mod.rs index a9f719935c69..fbd514ab2c03 100644 --- a/py-polars/src/dataframe/mod.rs +++ b/crates/polars-python/src/dataframe/mod.rs @@ -1,14 +1,16 @@ +#[cfg(feature = "pymethods")] mod construction; +#[cfg(feature = "pymethods")] mod export; +#[cfg(feature = "pymethods")] mod general; +#[cfg(feature = "pymethods")] mod io; +#[cfg(feature = "pymethods")] mod serde; -use polars::prelude::*; -use pyo3::prelude::*; -use pyo3::types::PyDict; - -use crate::error::PyPolarsErr; +use polars::prelude::DataFrame; +use pyo3::pyclass; #[pyclass] #[repr(transparent)] diff --git a/py-polars/src/dataframe/serde.rs b/crates/polars-python/src/dataframe/serde.rs similarity index 99% rename from py-polars/src/dataframe/serde.rs rename to crates/polars-python/src/dataframe/serde.rs index 524d894786cd..5bd54d5114af 100644 --- a/py-polars/src/dataframe/serde.rs +++ b/crates/polars-python/src/dataframe/serde.rs @@ -1,6 +1,7 @@ use std::io::{BufReader, BufWriter, Cursor}; use std::ops::Deref; +use polars::prelude::*; use polars_io::mmap::ReaderBytes; use pyo3::prelude::*; use pyo3::types::PyBytes; @@ -9,7 +10,6 @@ use super::PyDataFrame; use crate::error::PyPolarsErr; use crate::exceptions::ComputeError; use crate::file::{get_file_like, get_mmap_bytes_reader}; -use crate::prelude::*; #[pymethods] impl PyDataFrame { diff --git a/py-polars/src/datatypes.rs b/crates/polars-python/src/datatypes.rs similarity index 97% rename from py-polars/src/datatypes.rs rename to crates/polars-python/src/datatypes.rs index fa06b23e48b9..a31a2301f866 100644 --- a/py-polars/src/datatypes.rs +++ b/crates/polars-python/src/datatypes.rs @@ -103,7 +103,7 @@ impl From for DataType { #[cfg(feature = "object")] PyDataType::Object => Object(OBJECT_NAME, None), PyDataType::Categorical => Categorical(None, Default::default()), - PyDataType::Enum(categories) => create_enum_data_type(categories), + PyDataType::Enum(categories) => create_enum_dtype(categories), PyDataType::Struct => Struct(vec![]), PyDataType::Decimal(p, s) => Decimal(p, Some(s)), PyDataType::Array(width) => Array(DataType::Null.into(), width), diff --git a/py-polars/src/error.rs b/crates/polars-python/src/error.rs similarity index 100% rename from py-polars/src/error.rs rename to crates/polars-python/src/error.rs diff --git a/py-polars/src/exceptions.rs b/crates/polars-python/src/exceptions.rs similarity index 100% rename from py-polars/src/exceptions.rs rename to crates/polars-python/src/exceptions.rs diff --git a/py-polars/src/expr/array.rs b/crates/polars-python/src/expr/array.rs similarity index 95% rename from py-polars/src/expr/array.rs rename to crates/polars-python/src/expr/array.rs index 01e44208e5ff..f94185d8057c 100644 --- a/py-polars/src/expr/array.rs +++ b/crates/polars-python/src/expr/array.rs @@ -1,10 +1,9 @@ -use std::borrow::Cow; - use polars::prelude::*; use polars_ops::prelude::array::ArrToStructNameGenerator; +use 
polars_utils::pl_str::PlSmallStr; use pyo3::prelude::*; +use pyo3::pybacked::PyBackedStr; use pyo3::pymethods; -use smartstring::alias::String as SmartString; use crate::expr::PyExpr; @@ -114,7 +113,7 @@ impl PyExpr { Arc::new(move |idx: usize| { Python::with_gil(|py| { let out = lambda.call1(py, (idx,)).unwrap(); - let out: SmartString = out.extract::>(py).unwrap().into(); + let out: PlSmallStr = (&*out.extract::(py).unwrap()).into(); out }) }) as ArrToStructNameGenerator diff --git a/py-polars/src/expr/binary.rs b/crates/polars-python/src/expr/binary.rs similarity index 100% rename from py-polars/src/expr/binary.rs rename to crates/polars-python/src/expr/binary.rs diff --git a/py-polars/src/expr/categorical.rs b/crates/polars-python/src/expr/categorical.rs similarity index 100% rename from py-polars/src/expr/categorical.rs rename to crates/polars-python/src/expr/categorical.rs diff --git a/py-polars/src/expr/datetime.rs b/crates/polars-python/src/expr/datetime.rs similarity index 94% rename from py-polars/src/expr/datetime.rs rename to crates/polars-python/src/expr/datetime.rs index 5065ba676cad..69325b03a19f 100644 --- a/py-polars/src/expr/datetime.rs +++ b/crates/polars-python/src/expr/datetime.rs @@ -46,8 +46,12 @@ impl PyExpr { } #[cfg(feature = "timezones")] - fn dt_convert_time_zone(&self, time_zone: TimeZone) -> Self { - self.inner.clone().dt().convert_time_zone(time_zone).into() + fn dt_convert_time_zone(&self, time_zone: String) -> Self { + self.inner + .clone() + .dt() + .convert_time_zone(time_zone.into()) + .into() } fn dt_cast_time_unit(&self, time_unit: Wrap) -> Self { @@ -65,7 +69,7 @@ impl PyExpr { self.inner .clone() .dt() - .replace_time_zone(time_zone, ambiguous.inner, non_existent.0) + .replace_time_zone(time_zone.map(|x| x.into()), ambiguous.inner, non_existent.0) .into() } diff --git a/py-polars/src/expr/general.rs b/crates/polars-python/src/expr/general.rs similarity index 99% rename from py-polars/src/expr/general.rs rename to crates/polars-python/src/expr/general.rs index cfcfb438fda7..42779eb5bf2c 100644 --- a/py-polars/src/expr/general.rs +++ b/crates/polars-python/src/expr/general.rs @@ -228,7 +228,7 @@ impl PyExpr { fn value_counts(&self, sort: bool, parallel: bool, name: String, normalize: bool) -> Self { self.inner .clone() - .value_counts(sort, parallel, name, normalize) + .value_counts(sort, parallel, name.as_str(), normalize) .into() } fn unique_counts(&self) -> Self { @@ -237,8 +237,8 @@ impl PyExpr { fn null_count(&self) -> Self { self.inner.clone().null_count().into() } - fn cast(&self, data_type: Wrap, strict: bool, wrap_numerical: bool) -> Self { - let dt = data_type.0; + fn cast(&self, dtype: Wrap, strict: bool, wrap_numerical: bool) -> Self { + let dt = dtype.0; let options = if wrap_numerical { CastOptions::Overflowing diff --git a/py-polars/src/expr/list.rs b/crates/polars-python/src/expr/list.rs similarity index 98% rename from py-polars/src/expr/list.rs rename to crates/polars-python/src/expr/list.rs index 9ab917918b83..cb179eb0e859 100644 --- a/py-polars/src/expr/list.rs +++ b/crates/polars-python/src/expr/list.rs @@ -2,8 +2,8 @@ use std::borrow::Cow; use polars::prelude::*; use polars::series::ops::NullBehavior; +use polars_utils::pl_str::PlSmallStr; use pyo3::prelude::*; -use smartstring::alias::String as SmartString; use crate::conversion::Wrap; use crate::PyExpr; @@ -214,7 +214,7 @@ impl PyExpr { Arc::new(move |idx: usize| { Python::with_gil(|py| { let out = lambda.call1(py, (idx,)).unwrap(); - let out: SmartString = 
out.extract::>(py).unwrap().into(); + let out: PlSmallStr = out.extract::>(py).unwrap().as_ref().into(); out }) }) as NameGenerator diff --git a/py-polars/src/expr/meta.rs b/crates/polars-python/src/expr/meta.rs similarity index 94% rename from py-polars/src/expr/meta.rs rename to crates/polars-python/src/expr/meta.rs index 686227154bff..25b5eabd4b61 100644 --- a/py-polars/src/expr/meta.rs +++ b/crates/polars-python/src/expr/meta.rs @@ -10,8 +10,9 @@ impl PyExpr { self.inner == other.inner } - fn meta_pop(&self) -> Vec { - self.inner.clone().meta().pop().to_pyexprs() + fn meta_pop(&self) -> PyResult> { + let exprs = self.inner.clone().meta().pop().map_err(PyPolarsErr::from)?; + Ok(exprs.to_pyexprs()) } fn meta_root_names(&self) -> Vec { diff --git a/py-polars/src/expr/mod.rs b/crates/polars-python/src/expr/mod.rs similarity index 78% rename from py-polars/src/expr/mod.rs rename to crates/polars-python/src/expr/mod.rs index 0206f74ca0aa..85d44fefbf98 100644 --- a/py-polars/src/expr/mod.rs +++ b/crates/polars-python/src/expr/mod.rs @@ -1,21 +1,32 @@ +#[cfg(feature = "pymethods")] mod array; +#[cfg(feature = "pymethods")] mod binary; +#[cfg(feature = "pymethods")] mod categorical; +#[cfg(feature = "pymethods")] mod datetime; +#[cfg(feature = "pymethods")] mod general; +#[cfg(feature = "pymethods")] mod list; -#[cfg(feature = "meta")] +#[cfg(all(feature = "meta", feature = "pymethods"))] mod meta; +#[cfg(feature = "pymethods")] mod name; +#[cfg(feature = "pymethods")] mod rolling; +#[cfg(feature = "pymethods")] mod serde; +#[cfg(feature = "pymethods")] mod string; +#[cfg(feature = "pymethods")] mod r#struct; use std::mem::ManuallyDrop; use polars::lazy::dsl::Expr; -use pyo3::prelude::*; +use pyo3::pyclass; #[pyclass] #[repr(transparent)] diff --git a/py-polars/src/expr/name.rs b/crates/polars-python/src/expr/name.rs similarity index 87% rename from py-polars/src/expr/name.rs rename to crates/polars-python/src/expr/name.rs index 6bbda4a6668a..e5be57ac9458 100644 --- a/py-polars/src/expr/name.rs +++ b/crates/polars-python/src/expr/name.rs @@ -1,8 +1,9 @@ use std::borrow::Cow; use polars::prelude::*; +use polars_utils::format_pl_smallstr; +use polars_utils::pl_str::PlSmallStr; use pyo3::prelude::*; -use smartstring::alias::String as SmartString; use crate::PyExpr; @@ -17,9 +18,9 @@ impl PyExpr { .clone() .name() .map(move |name| { - let out = Python::with_gil(|py| lambda.call1(py, (name,))); + let out = Python::with_gil(|py| lambda.call1(py, (name.as_str(),))); match out { - Ok(out) => Ok(out.to_string()), + Ok(out) => Ok(format_pl_smallstr!("{}", out)), Err(e) => Err(PolarsError::ComputeError( format!("Python function in 'name.map' produced an error: {e}.").into(), )), @@ -48,7 +49,7 @@ impl PyExpr { let name_mapper = Arc::new(move |name: &str| { Python::with_gil(|py| { let out = name_mapper.call1(py, (name,)).unwrap(); - let out: SmartString = out.extract::>(py).unwrap().into(); + let out: PlSmallStr = out.extract::>(py).unwrap().as_ref().into(); out }) }) as FieldsNameMapper; diff --git a/py-polars/src/expr/rolling.rs b/crates/polars-python/src/expr/rolling.rs similarity index 81% rename from py-polars/src/expr/rolling.rs rename to crates/polars-python/src/expr/rolling.rs index b854cb4bd89b..712accafb839 100644 --- a/py-polars/src/expr/rolling.rs +++ b/crates/polars-python/src/expr/rolling.rs @@ -363,81 +363,105 @@ impl PyExpr { UInt8 => { if is_float { let v = obj.extract::(py).unwrap(); - Ok(UInt8Chunked::from_slice("", &[v as u8]).into_series()) + 
Ok(UInt8Chunked::from_slice(PlSmallStr::EMPTY, &[v as u8]) + .into_series()) } else { - obj.extract::(py) - .map(|v| UInt8Chunked::from_slice("", &[v]).into_series()) + obj.extract::(py).map(|v| { + UInt8Chunked::from_slice(PlSmallStr::EMPTY, &[v]) + .into_series() + }) } }, UInt16 => { if is_float { let v = obj.extract::(py).unwrap(); - Ok(UInt16Chunked::from_slice("", &[v as u16]).into_series()) + Ok(UInt16Chunked::from_slice(PlSmallStr::EMPTY, &[v as u16]) + .into_series()) } else { - obj.extract::(py) - .map(|v| UInt16Chunked::from_slice("", &[v]).into_series()) + obj.extract::(py).map(|v| { + UInt16Chunked::from_slice(PlSmallStr::EMPTY, &[v]) + .into_series() + }) } }, UInt32 => { if is_float { let v = obj.extract::(py).unwrap(); - Ok(UInt32Chunked::from_slice("", &[v as u32]).into_series()) + Ok(UInt32Chunked::from_slice(PlSmallStr::EMPTY, &[v as u32]) + .into_series()) } else { - obj.extract::(py) - .map(|v| UInt32Chunked::from_slice("", &[v]).into_series()) + obj.extract::(py).map(|v| { + UInt32Chunked::from_slice(PlSmallStr::EMPTY, &[v]) + .into_series() + }) } }, UInt64 => { if is_float { let v = obj.extract::(py).unwrap(); - Ok(UInt64Chunked::from_slice("", &[v as u64]).into_series()) + Ok(UInt64Chunked::from_slice(PlSmallStr::EMPTY, &[v as u64]) + .into_series()) } else { - obj.extract::(py) - .map(|v| UInt64Chunked::from_slice("", &[v]).into_series()) + obj.extract::(py).map(|v| { + UInt64Chunked::from_slice(PlSmallStr::EMPTY, &[v]) + .into_series() + }) } }, Int8 => { if is_float { let v = obj.extract::(py).unwrap(); - Ok(Int8Chunked::from_slice("", &[v as i8]).into_series()) + Ok(Int8Chunked::from_slice(PlSmallStr::EMPTY, &[v as i8]) + .into_series()) } else { - obj.extract::(py) - .map(|v| Int8Chunked::from_slice("", &[v]).into_series()) + obj.extract::(py).map(|v| { + Int8Chunked::from_slice(PlSmallStr::EMPTY, &[v]) + .into_series() + }) } }, Int16 => { if is_float { let v = obj.extract::(py).unwrap(); - Ok(Int16Chunked::from_slice("", &[v as i16]).into_series()) + Ok(Int16Chunked::from_slice(PlSmallStr::EMPTY, &[v as i16]) + .into_series()) } else { - obj.extract::(py) - .map(|v| Int16Chunked::from_slice("", &[v]).into_series()) + obj.extract::(py).map(|v| { + Int16Chunked::from_slice(PlSmallStr::EMPTY, &[v]) + .into_series() + }) } }, Int32 => { if is_float { let v = obj.extract::(py).unwrap(); - Ok(Int32Chunked::from_slice("", &[v as i32]).into_series()) + Ok(Int32Chunked::from_slice(PlSmallStr::EMPTY, &[v as i32]) + .into_series()) } else { - obj.extract::(py) - .map(|v| Int32Chunked::from_slice("", &[v]).into_series()) + obj.extract::(py).map(|v| { + Int32Chunked::from_slice(PlSmallStr::EMPTY, &[v]) + .into_series() + }) } }, Int64 => { if is_float { let v = obj.extract::(py).unwrap(); - Ok(Int64Chunked::from_slice("", &[v as i64]).into_series()) + Ok(Int64Chunked::from_slice(PlSmallStr::EMPTY, &[v as i64]) + .into_series()) } else { - obj.extract::(py) - .map(|v| Int64Chunked::from_slice("", &[v]).into_series()) + obj.extract::(py).map(|v| { + Int64Chunked::from_slice(PlSmallStr::EMPTY, &[v]) + .into_series() + }) } }, - Float32 => obj - .extract::(py) - .map(|v| Float32Chunked::from_slice("", &[v]).into_series()), - Float64 => obj - .extract::(py) - .map(|v| Float64Chunked::from_slice("", &[v]).into_series()), + Float32 => obj.extract::(py).map(|v| { + Float32Chunked::from_slice(PlSmallStr::EMPTY, &[v]).into_series() + }), + Float64 => obj.extract::(py).map(|v| { + Float64Chunked::from_slice(PlSmallStr::EMPTY, &[v]).into_series() + }), dt => panic!("{dt:?} not 
implemented"), }; diff --git a/py-polars/src/expr/serde.rs b/crates/polars-python/src/expr/serde.rs similarity index 100% rename from py-polars/src/expr/serde.rs rename to crates/polars-python/src/expr/serde.rs diff --git a/py-polars/src/expr/string.rs b/crates/polars-python/src/expr/string.rs similarity index 97% rename from py-polars/src/expr/string.rs rename to crates/polars-python/src/expr/string.rs index 55f2aa71140b..e238e412dc02 100644 --- a/py-polars/src/expr/string.rs +++ b/crates/polars-python/src/expr/string.rs @@ -17,6 +17,8 @@ impl PyExpr { #[pyo3(signature = (format, strict, exact, cache))] fn str_to_date(&self, format: Option, strict: bool, exact: bool, cache: bool) -> Self { + let format = format.map(|x| x.into()); + let options = StrptimeOptions { format, strict, @@ -31,12 +33,15 @@ impl PyExpr { &self, format: Option, time_unit: Option>, - time_zone: Option, + time_zone: Option>, strict: bool, exact: bool, cache: bool, ambiguous: Self, ) -> Self { + let format = format.map(|x| x.into()); + let time_zone = time_zone.map(|x| x.0); + let options = StrptimeOptions { format, strict, @@ -57,6 +62,8 @@ impl PyExpr { #[pyo3(signature = (format, strict, cache))] fn str_to_time(&self, format: Option, strict: bool, cache: bool) -> Self { + let format = format.map(|x| x.into()); + let options = StrptimeOptions { format, strict, diff --git a/py-polars/src/expr/struct.rs b/crates/polars-python/src/expr/struct.rs similarity index 100% rename from py-polars/src/expr/struct.rs rename to crates/polars-python/src/expr/struct.rs diff --git a/py-polars/src/file.rs b/crates/polars-python/src/file.rs similarity index 65% rename from py-polars/src/file.rs rename to crates/polars-python/src/file.rs index 93494eeea179..33d084c5130c 100644 --- a/py-polars/src/file.rs +++ b/crates/polars-python/src/file.rs @@ -1,5 +1,7 @@ use std::borrow::Cow; -use std::fs::{self, File}; +#[cfg(target_family = "unix")] +use std::fs; +use std::fs::File; use std::io; use std::io::{Cursor, ErrorKind, Read, Seek, SeekFrom, Write}; #[cfg(target_family = "unix")] @@ -10,7 +12,7 @@ use polars::io::mmap::MmapBytesReader; use polars_error::{polars_err, polars_warn}; use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; -use pyo3::types::{PyBytes, PyString}; +use pyo3::types::{PyBytes, PyString, PyStringMethods}; use crate::error::PyPolarsErr; use crate::prelude::resolve_homedir; @@ -29,6 +31,10 @@ impl PyFileLikeObject { PyFileLikeObject { inner: object } } + pub fn as_bytes(&self) -> bytes::Bytes { + self.as_file_buffer().into_inner().into() + } + pub fn as_buffer(&self) -> std::io::Cursor> { let data = self.as_file_buffer().into_inner(); std::io::Cursor::new(data) @@ -41,11 +47,19 @@ impl PyFileLikeObject { .call_method_bound(py, "read", (), None) .expect("no read method found"); - let bytes: &Bound<'_, PyBytes> = bytes - .downcast_bound(py) - .expect("Expecting to be able to downcast into bytes from read result."); + if let Ok(bytes) = bytes.downcast_bound::(py) { + return bytes.as_bytes().to_vec(); + } - bytes.as_bytes().to_vec() + if let Ok(bytes) = bytes.downcast_bound::(py) { + return bytes + .to_cow() + .expect("PyString is not valid UTF-8") + .into_owned() + .into_bytes(); + } + + panic!("Expecting to be able to downcast into bytes from read result."); }); Cursor::new(buf) @@ -189,7 +203,127 @@ impl EitherRustPythonFile { } } -fn get_either_file_and_path( +pub enum PythonScanSourceInput { + Buffer(bytes::Bytes), + Path(PathBuf), + File(File), +} + +pub fn get_python_scan_source_input( + py_f: PyObject, + write: 
bool, +) -> PyResult { + Python::with_gil(|py| { + let py_f = py_f.into_bound(py); + + // If the pyobject is a `bytes` class + if let Ok(bytes) = py_f.downcast::() { + return Ok(PythonScanSourceInput::Buffer( + bytes::Bytes::copy_from_slice(bytes.as_bytes()), + )); + } + + if let Ok(s) = py_f.extract::>() { + let file_path = std::path::Path::new(&*s); + let file_path = resolve_homedir(file_path); + Ok(PythonScanSourceInput::Path(file_path)) + } else { + let io = py.import_bound("io").unwrap(); + let is_utf8_encoding = |py_f: &Bound| -> PyResult { + let encoding = py_f.getattr("encoding")?; + let encoding = encoding.extract::>()?; + Ok(encoding.eq_ignore_ascii_case("utf-8") || encoding.eq_ignore_ascii_case("utf8")) + }; + + #[cfg(target_family = "unix")] + if let Some(fd) = (py_f.is_exact_instance(&io.getattr("FileIO").unwrap()) + || (py_f.is_exact_instance(&io.getattr("BufferedReader").unwrap()) + || py_f.is_exact_instance(&io.getattr("BufferedWriter").unwrap()) + || py_f.is_exact_instance(&io.getattr("BufferedRandom").unwrap()) + || py_f.is_exact_instance(&io.getattr("BufferedRWPair").unwrap()) + || (py_f.is_exact_instance(&io.getattr("TextIOWrapper").unwrap()) + && is_utf8_encoding(&py_f)?)) + && if write { + // invalidate read buffer + py_f.call_method0("flush").is_ok() + } else { + // flush write buffer + py_f.call_method1("seek", (0, 1)).is_ok() + }) + .then(|| { + py_f.getattr("fileno") + .and_then(|fileno| fileno.call0()) + .and_then(|fileno| fileno.extract::()) + .ok() + }) + .flatten() + .map(|fileno| unsafe { + // `File::from_raw_fd()` takes the ownership of the file descriptor. + // When the File is dropped, it closes the file descriptor. + // This is undesired - the Python file object will become invalid. + // Therefore, we duplicate the file descriptor here. + // Closing the duplicated file descriptor will not close + // the original file descriptor; + // and the status, e.g. stream position, is still shared with + // the original file descriptor. + // We use `F_DUPFD_CLOEXEC` here instead of `dup()` + // because it also sets the `O_CLOEXEC` flag on the duplicated file descriptor, + // which `dup()` clears. + // `open()` in both Rust and Python automatically set `O_CLOEXEC` flag; + // it prevents leaking file descriptors across processes, + // and we want to be consistent with them. + // `F_DUPFD_CLOEXEC` is defined in POSIX.1-2008 + // and is present on all alive UNIX(-like) systems. + libc::fcntl(fileno, libc::F_DUPFD_CLOEXEC, 0) + }) + .filter(|fileno| *fileno != -1) + .map(|fileno| fileno as RawFd) + { + return Ok(PythonScanSourceInput::File(unsafe { + File::from_raw_fd(fd) + })); + } + + // BytesIO / StringIO is relatively fast, and some code relies on it. + if !py_f.is_exact_instance(&io.getattr("BytesIO").unwrap()) + && !py_f.is_exact_instance(&io.getattr("StringIO").unwrap()) + { + polars_warn!("Polars found a filename. \ + Ensure you pass a path to the file instead of a python file object when possible for best \ + performance."); + } + // Unwrap TextIOWrapper + // Allow subclasses to allow things like pytest.capture.CaptureIO + let py_f = if py_f + .is_instance(&io.getattr("TextIOWrapper").unwrap()) + .unwrap_or_default() + { + if !is_utf8_encoding(&py_f)? { + return Err(PyPolarsErr::from( + polars_err!(InvalidOperation: "file encoding is not UTF-8"), + ) + .into()); + } + // XXX: we have to clear buffer here. + // Is there a better solution? + if write { + py_f.call_method0("flush")?; + } else { + py_f.call_method1("seek", (0, 1))?; + } + py_f.getattr("buffer")? 
+ } else { + py_f + }; + PyFileLikeObject::ensure_requirements(&py_f, !write, write, !write)?; + Ok(PythonScanSourceInput::Buffer( + PyFileLikeObject::new(py_f.to_object(py)).as_bytes(), + )) + } + }) +} + +fn get_either_buffer_or_path( py_f: PyObject, write: bool, ) -> PyResult<(EitherRustPythonFile, Option)> { @@ -263,8 +397,10 @@ fn get_either_file_and_path( )); } - // BytesIO is relatively fast, and some code relies on it. - if !py_f.is_exact_instance(&io.getattr("BytesIO").unwrap()) { + // BytesIO / StringIO is relatively fast, and some code relies on it. + if !py_f.is_exact_instance(&io.getattr("BytesIO").unwrap()) + && !py_f.is_exact_instance(&io.getattr("StringIO").unwrap()) + { polars_warn!("Polars found a filename. \ Ensure you pass a path to the file instead of a python file object when possible for best \ performance."); @@ -303,7 +439,7 @@ fn get_either_file_and_path( /// # Arguments /// * `write` - open for writing; will truncate existing file and create new file if not. pub fn get_either_file(py_f: PyObject, write: bool) -> PyResult { - Ok(get_either_file_and_path(py_f, write)?.0) + Ok(get_either_buffer_or_path(py_f, write)?.0) } pub fn get_file_like(f: PyObject, truncate: bool) -> PyResult> { @@ -340,7 +476,7 @@ pub fn get_mmap_bytes_reader_and_path<'a>( } // string so read file else { - match get_either_file_and_path(py_f.to_object(py_f.py()), false)? { + match get_either_buffer_or_path(py_f.to_object(py_f.py()), false)? { (EitherRustPythonFile::Rust(f), path) => Ok((Box::new(f), path)), (EitherRustPythonFile::Py(f), path) => Ok((Box::new(f), path)), } diff --git a/py-polars/src/functions/aggregation.rs b/crates/polars-python/src/functions/aggregation.rs similarity index 100% rename from py-polars/src/functions/aggregation.rs rename to crates/polars-python/src/functions/aggregation.rs diff --git a/py-polars/src/functions/business.rs b/crates/polars-python/src/functions/business.rs similarity index 100% rename from py-polars/src/functions/business.rs rename to crates/polars-python/src/functions/business.rs diff --git a/py-polars/src/functions/eager.rs b/crates/polars-python/src/functions/eager.rs similarity index 100% rename from py-polars/src/functions/eager.rs rename to crates/polars-python/src/functions/eager.rs diff --git a/py-polars/src/functions/io.rs b/crates/polars-python/src/functions/io.rs similarity index 82% rename from py-polars/src/functions/io.rs rename to crates/polars-python/src/functions/io.rs index f6da57e5fc3d..fe681ce648e2 100644 --- a/py-polars/src/functions/io.rs +++ b/crates/polars-python/src/functions/io.rs @@ -1,8 +1,9 @@ use std::io::BufReader; -use polars_core::datatypes::create_enum_data_type; +#[cfg(any(feature = "ipc", feature = "parquet"))] +use polars::prelude::ArrowSchema; +use polars_core::datatypes::create_enum_dtype; use polars_core::export::arrow::array::Utf8ViewArray; -use polars_core::export::arrow::datatypes::Field; use polars_core::prelude::{DTYPE_ENUM_KEY, DTYPE_ENUM_VALUE}; use pyo3::prelude::*; use pyo3::types::PyDict; @@ -24,7 +25,7 @@ pub fn read_ipc_schema(py: Python, py_f: PyObject) -> PyResult { }; let dict = PyDict::new_bound(py); - fields_to_pydict(&metadata.schema.fields, &dict, py)?; + fields_to_pydict(&metadata.schema, &dict, py)?; Ok(dict.to_object(py)) } @@ -42,21 +43,21 @@ pub fn read_parquet_schema(py: Python, py_f: PyObject) -> PyResult { let arrow_schema = infer_schema(&metadata).map_err(PyPolarsErr::from)?; let dict = PyDict::new_bound(py); - fields_to_pydict(&arrow_schema.fields, &dict, py)?; + 
fields_to_pydict(&arrow_schema, &dict, py)?; Ok(dict.to_object(py)) } #[cfg(any(feature = "ipc", feature = "parquet"))] -fn fields_to_pydict(fields: &Vec, dict: &Bound<'_, PyDict>, py: Python) -> PyResult<()> { - for field in fields { +fn fields_to_pydict(schema: &ArrowSchema, dict: &Bound<'_, PyDict>, py: Python) -> PyResult<()> { + for field in schema.iter_values() { let dt = if field.metadata.get(DTYPE_ENUM_KEY) == Some(&DTYPE_ENUM_VALUE.into()) { - Wrap(create_enum_data_type(Utf8ViewArray::new_empty( + Wrap(create_enum_dtype(Utf8ViewArray::new_empty( ArrowDataType::LargeUtf8, ))) } else { - Wrap((&field.data_type).into()) + Wrap((&field.dtype).into()) }; - dict.set_item(&field.name, dt.to_object(py))?; + dict.set_item(field.name.as_str(), dt.to_object(py))?; } Ok(()) } diff --git a/py-polars/src/functions/lazy.rs b/crates/polars-python/src/functions/lazy.rs similarity index 88% rename from py-polars/src/functions/lazy.rs rename to crates/polars-python/src/functions/lazy.rs index 49325e617170..108aaf2121b1 100644 --- a/py-polars/src/functions/lazy.rs +++ b/crates/polars-python/src/functions/lazy.rs @@ -97,13 +97,7 @@ pub fn as_struct(exprs: Vec) -> PyResult { #[pyfunction] pub fn field(names: Vec) -> PyExpr { - dsl::Expr::Field( - names - .into_iter() - .map(|name| Arc::from(name.as_str())) - .collect(), - ) - .into() + dsl::Expr::Field(names.into_iter().map(|x| x.into()).collect()).into() } #[pyfunction] @@ -254,7 +248,7 @@ pub fn datetime( second: Option, microsecond: Option, time_unit: Wrap, - time_zone: Option, + time_zone: Option>, ambiguous: Option, ) -> PyExpr { let year = year.inner; @@ -265,6 +259,7 @@ pub fn datetime( .map(|e| e.inner) .unwrap_or(dsl::lit(String::from("raise"))); let time_unit = time_unit.0; + let time_zone = time_zone.map(|x| x.0); let args = DatetimeArgs { year, month, @@ -435,21 +430,37 @@ pub fn lit(value: &Bound<'_, PyAny>, allow_object: bool) -> PyResult { Ok(dsl::lit(Null {}).into()) } else if let Ok(value) = value.downcast::() { Ok(dsl::lit(value.as_bytes()).into()) - } else if value.get_type().qualname().unwrap() == "Decimal" { + } else if matches!( + value.get_type().qualname().unwrap().as_str(), + "date" | "datetime" | "time" | "timedelta" | "Decimal" + ) { let av = py_object_to_any_value(value, true)?; Ok(Expr::Literal(LiteralValue::try_from(av).unwrap()).into()) - } else if allow_object { - let s = Python::with_gil(|py| { - PySeries::new_object(py, "", vec![ObjectValue::from(value.into_py(py))], false).series - }); - Ok(dsl::lit(s).into()) } else { - Err(PyTypeError::new_err(format!( - "cannot create expression literal for value of type {}: {}\ - \n\nHint: Pass `allow_object=True` to accept any value and create a literal of type Object.", - value.get_type().qualname()?, - value.repr()? - ))) + Python::with_gil(|py| { + // One final attempt before erroring. Do we have a date/datetime subclass? + // E.g. pd.Timestamp, or Freezegun. + let datetime_module = PyModule::import_bound(py, "datetime")?; + let datetime_class = datetime_module.getattr("datetime")?; + let date_class = datetime_module.getattr("date")?; + if value.is_instance(&datetime_class)? || value.is_instance(&date_class)? 
{ + let av = py_object_to_any_value(value, true)?; + Ok(Expr::Literal(LiteralValue::try_from(av).unwrap()).into()) + } else if allow_object { + let s = Python::with_gil(|py| { + PySeries::new_object(py, "", vec![ObjectValue::from(value.into_py(py))], false) + .series + }); + Ok(dsl::lit(s).into()) + } else { + Err(PyTypeError::new_err(format!( + "cannot create expression literal for value of type {}: {}\ + \n\nHint: Pass `allow_object=True` to accept any value and create a literal of type Object.", + value.get_type().qualname()?, + value.repr()? + ))) + } + }) } } diff --git a/py-polars/src/functions/meta.rs b/crates/polars-python/src/functions/meta.rs similarity index 100% rename from py-polars/src/functions/meta.rs rename to crates/polars-python/src/functions/meta.rs diff --git a/py-polars/src/functions/misc.rs b/crates/polars-python/src/functions/misc.rs similarity index 87% rename from py-polars/src/functions/misc.rs rename to crates/polars-python/src/functions/misc.rs index 5f372068e8e4..2ade770d728e 100644 --- a/py-polars/src/functions/misc.rs +++ b/crates/polars-python/src/functions/misc.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use polars_plan::prelude::*; use pyo3::prelude::*; @@ -52,9 +50,9 @@ pub fn register_plugin_function( Ok(Expr::Function { input: args.to_exprs(), function: FunctionExpr::FfiPlugin { - lib: Arc::from(plugin_path), - symbol: Arc::from(function_name), - kwargs: Arc::from(kwargs), + lib: plugin_path.into(), + symbol: function_name.into(), + kwargs: kwargs.into(), }, options: FunctionOptions { collect_groups, @@ -65,3 +63,8 @@ pub fn register_plugin_function( } .into()) } + +#[pyfunction] +pub fn __register_startup_deps() { + crate::on_startup::register_startup_deps() +} diff --git a/py-polars/src/functions/mod.rs b/crates/polars-python/src/functions/mod.rs similarity index 100% rename from py-polars/src/functions/mod.rs rename to crates/polars-python/src/functions/mod.rs diff --git a/py-polars/src/functions/random.rs b/crates/polars-python/src/functions/random.rs similarity index 100% rename from py-polars/src/functions/random.rs rename to crates/polars-python/src/functions/random.rs diff --git a/py-polars/src/functions/range.rs b/crates/polars-python/src/functions/range.rs similarity index 93% rename from py-polars/src/functions/range.rs rename to crates/polars-python/src/functions/range.rs index ce725dda4ca4..b07522650de3 100644 --- a/py-polars/src/functions/range.rs +++ b/crates/polars-python/src/functions/range.rs @@ -34,7 +34,7 @@ pub fn eager_int_range( let start_v: <$T as PolarsNumericType>::Native = lower.extract()?; let end_v: <$T as PolarsNumericType>::Native = upper.extract()?; let step: i64 = step.extract()?; - new_int_range::<$T>(start_v, end_v, step, "literal") + new_int_range::<$T>(start_v, end_v, step, PlSmallStr::from_static("literal")) }); let s = ret.map_err(PyPolarsErr::from)?; @@ -100,13 +100,14 @@ pub fn datetime_range( every: &str, closed: Wrap, time_unit: Option>, - time_zone: Option, + time_zone: Option>, ) -> PyExpr { let start = start.inner; let end = end.inner; let every = Duration::parse(every); let closed = closed.0; let time_unit = time_unit.map(|x| x.0); + let time_zone = time_zone.map(|x| x.0); dsl::datetime_range(start, end, every, closed, time_unit, time_zone).into() } @@ -117,13 +118,14 @@ pub fn datetime_ranges( every: &str, closed: Wrap, time_unit: Option>, - time_zone: Option, + time_zone: Option>, ) -> PyExpr { let start = start.inner; let end = end.inner; let every = Duration::parse(every); let closed = closed.0; let 
time_unit = time_unit.map(|x| x.0); + let time_zone = time_zone.map(|x| x.0); dsl::datetime_ranges(start, end, every, closed, time_unit, time_zone).into() } diff --git a/py-polars/src/functions/string_cache.rs b/crates/polars-python/src/functions/string_cache.rs similarity index 100% rename from py-polars/src/functions/string_cache.rs rename to crates/polars-python/src/functions/string_cache.rs diff --git a/py-polars/src/functions/whenthen.rs b/crates/polars-python/src/functions/whenthen.rs similarity index 100% rename from py-polars/src/functions/whenthen.rs rename to crates/polars-python/src/functions/whenthen.rs diff --git a/py-polars/src/gil_once_cell.rs b/crates/polars-python/src/gil_once_cell.rs similarity index 97% rename from py-polars/src/gil_once_cell.rs rename to crates/polars-python/src/gil_once_cell.rs index 5608283a214c..17a79334560c 100644 --- a/py-polars/src/gil_once_cell.rs +++ b/crates/polars-python/src/gil_once_cell.rs @@ -14,6 +14,7 @@ unsafe impl Send for GILOnceCell {} impl GILOnceCell { /// Create a `GILOnceCell` which does not yet contain a value. + #[allow(clippy::new_without_default)] pub const fn new() -> Self { Self(UnsafeCell::new(None)) } diff --git a/py-polars/src/interop/arrow/mod.rs b/crates/polars-python/src/interop/arrow/mod.rs similarity index 100% rename from py-polars/src/interop/arrow/mod.rs rename to crates/polars-python/src/interop/arrow/mod.rs diff --git a/py-polars/src/interop/arrow/to_py.rs b/crates/polars-python/src/interop/arrow/to_py.rs similarity index 89% rename from py-polars/src/interop/arrow/to_py.rs rename to crates/polars-python/src/interop/arrow/to_py.rs index 2581a52f34ce..de6c07ef31c9 100644 --- a/py-polars/src/interop/arrow/to_py.rs +++ b/crates/polars-python/src/interop/arrow/to_py.rs @@ -5,7 +5,7 @@ use arrow::ffi; use arrow::record_batch::RecordBatch; use polars::datatypes::CompatLevel; use polars::frame::DataFrame; -use polars::prelude::{ArrayRef, ArrowField}; +use polars::prelude::{ArrayRef, ArrowField, PlSmallStr, SchemaExt}; use polars::series::Series; use polars_core::utils::arrow; use polars_error::PolarsResult; @@ -20,8 +20,8 @@ pub(crate) fn to_py_array( pyarrow: &Bound, ) -> PyResult { let schema = Box::new(ffi::export_field_to_c(&ArrowField::new( - "", - array.data_type().clone(), + PlSmallStr::EMPTY, + array.dtype().clone(), true, ))); let array = Box::new(ffi::export_array_to_c(array)); @@ -84,7 +84,7 @@ pub(crate) fn dataframe_to_stream<'py>( pub struct DataFrameStreamIterator { columns: Vec, - data_type: ArrowDataType, + dtype: ArrowDataType, idx: usize, n_chunks: usize, } @@ -92,18 +92,18 @@ pub struct DataFrameStreamIterator { impl DataFrameStreamIterator { fn new(df: &DataFrame) -> Self { let schema = df.schema().to_arrow(CompatLevel::newest()); - let data_type = ArrowDataType::Struct(schema.fields); + let dtype = ArrowDataType::Struct(schema.into_iter_values().collect()); Self { columns: df.get_columns().to_vec(), - data_type, + dtype, idx: 0, n_chunks: df.n_chunks(), } } fn field(&self) -> ArrowField { - ArrowField::new("", self.data_type.clone(), false) + ArrowField::new(PlSmallStr::EMPTY, self.dtype.clone(), false) } } @@ -122,7 +122,7 @@ impl Iterator for DataFrameStreamIterator { .collect(); self.idx += 1; - let array = arrow::array::StructArray::new(self.data_type.clone(), batch_cols, None); + let array = arrow::array::StructArray::new(self.dtype.clone(), batch_cols, None); Some(Ok(Box::new(array))) } } diff --git a/py-polars/src/interop/arrow/to_rust.rs b/crates/polars-python/src/interop/arrow/to_rust.rs 
similarity index 87% rename from py-polars/src/interop/arrow/to_rust.rs rename to crates/polars-python/src/interop/arrow/to_rust.rs index 411a683ad778..8d76f53b243a 100644 --- a/py-polars/src/interop/arrow/to_rust.rs +++ b/crates/polars-python/src/interop/arrow/to_rust.rs @@ -41,7 +41,7 @@ pub fn array_to_rust(obj: &Bound) -> PyResult { unsafe { let field = ffi::import_field_from_c(schema.as_ref()).map_err(PyPolarsErr::from)?; - let array = ffi::import_array_from_c(*array, field.data_type).map_err(PyPolarsErr::from)?; + let array = ffi::import_array_from_c(*array, field.dtype).map_err(PyPolarsErr::from)?; Ok(array) } } @@ -51,7 +51,12 @@ pub fn to_rust_df(rb: &[Bound]) -> PyResult { .first() .ok_or_else(|| PyPolarsErr::Other("empty table".into()))? .getattr("schema")?; - let names = schema.getattr("names")?.extract::>()?; + let names = schema + .getattr("names")? + .extract::>()? + .into_iter() + .map(PlSmallStr::from_string) + .collect::>(); let dfs = rb .iter() @@ -63,7 +68,7 @@ pub fn to_rust_df(rb: &[Bound]) -> PyResult { let array = rb.call_method1("column", (i,))?; let arr = array_to_rust(&array)?; run_parallel |= matches!( - arr.data_type(), + arr.dtype(), ArrowDataType::Utf8 | ArrowDataType::Dictionary(_, _, _) ); Ok(arr) @@ -79,7 +84,7 @@ pub fn to_rust_df(rb: &[Bound]) -> PyResult { .into_par_iter() .enumerate() .map(|(i, arr)| { - let s = Series::try_from((names[i].as_str(), arr)) + let s = Series::try_from((names[i].clone(), arr)) .map_err(PyPolarsErr::from)?; Ok(s) }) @@ -90,8 +95,8 @@ pub fn to_rust_df(rb: &[Bound]) -> PyResult { .into_iter() .enumerate() .map(|(i, arr)| { - let s = Series::try_from((names[i].as_str(), arr)) - .map_err(PyPolarsErr::from)?; + let s = + Series::try_from((names[i].clone(), arr)).map_err(PyPolarsErr::from)?; Ok(s) }) .collect::>>() diff --git a/py-polars/src/interop/mod.rs b/crates/polars-python/src/interop/mod.rs similarity index 100% rename from py-polars/src/interop/mod.rs rename to crates/polars-python/src/interop/mod.rs diff --git a/py-polars/src/interop/numpy/mod.rs b/crates/polars-python/src/interop/numpy/mod.rs similarity index 100% rename from py-polars/src/interop/numpy/mod.rs rename to crates/polars-python/src/interop/numpy/mod.rs diff --git a/py-polars/src/interop/numpy/to_numpy_df.rs b/crates/polars-python/src/interop/numpy/to_numpy_df.rs similarity index 100% rename from py-polars/src/interop/numpy/to_numpy_df.rs rename to crates/polars-python/src/interop/numpy/to_numpy_df.rs diff --git a/py-polars/src/interop/numpy/to_numpy_series.rs b/crates/polars-python/src/interop/numpy/to_numpy_series.rs similarity index 100% rename from py-polars/src/interop/numpy/to_numpy_series.rs rename to crates/polars-python/src/interop/numpy/to_numpy_series.rs diff --git a/py-polars/src/interop/numpy/utils.rs b/crates/polars-python/src/interop/numpy/utils.rs similarity index 100% rename from py-polars/src/interop/numpy/utils.rs rename to crates/polars-python/src/interop/numpy/utils.rs diff --git a/py-polars/src/lazyframe/exitable.rs b/crates/polars-python/src/lazyframe/exitable.rs similarity index 88% rename from py-polars/src/lazyframe/exitable.rs rename to crates/polars-python/src/lazyframe/exitable.rs index 20e7ac5cd20f..f073689e89db 100644 --- a/py-polars/src/lazyframe/exitable.rs +++ b/crates/polars-python/src/lazyframe/exitable.rs @@ -1,4 +1,9 @@ -use super::*; +use polars::prelude::*; +use pyo3::prelude::*; + +use super::PyLazyFrame; +use crate::error::PyPolarsErr; +use crate::PyDataFrame; #[pymethods] impl PyLazyFrame { diff --git 
a/py-polars/src/lazyframe/mod.rs b/crates/polars-python/src/lazyframe/general.rs similarity index 82% rename from py-polars/src/lazyframe/mod.rs rename to crates/polars-python/src/lazyframe/general.rs index 96d28b3e78f0..86bcd3c2566b 100644 --- a/py-polars/src/lazyframe/mod.rs +++ b/crates/polars-python/src/lazyframe/general.rs @@ -1,23 +1,18 @@ -mod exitable; -mod visit; -pub(crate) mod visitor; use std::collections::HashMap; use std::num::NonZeroUsize; use std::path::PathBuf; -mod serde; -pub use exitable::PyInProcessQuery; use polars::io::{HiveOptions, RowIndex}; use polars::time::*; use polars_core::prelude::*; #[cfg(feature = "parquet")] use polars_parquet::arrow::write::StatisticsOptions; -use pyo3::exceptions::PyValueError; +use polars_plan::plans::ScanSources; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; use pyo3::types::{PyDict, PyList}; -pub(crate) use visit::PyExprIR; +use super::PyLazyFrame; use crate::error::PyPolarsErr; use crate::expr::ToExprs; use crate::interop::arrow::to_rust::pyarrow_schema_to_rust; @@ -25,17 +20,17 @@ use crate::lazyframe::visit::NodeTraverser; use crate::prelude::*; use crate::{PyDataFrame, PyExpr, PyLazyGroupBy}; -#[pyclass] -#[repr(transparent)] -#[derive(Clone)] -pub struct PyLazyFrame { - pub ldf: LazyFrame, -} - -impl From for PyLazyFrame { - fn from(ldf: LazyFrame) -> Self { - PyLazyFrame { ldf } - } +fn pyobject_to_first_path_and_scan_sources( + obj: PyObject, +) -> PyResult<(Option, ScanSources)> { + use crate::file::{get_python_scan_source_input, PythonScanSourceInput}; + Ok(match get_python_scan_source_input(obj, false)? { + PythonScanSourceInput::Path(path) => { + (Some(path.clone()), ScanSources::Paths([path].into())) + }, + PythonScanSourceInput::File(file) => (None, ScanSources::Files([file].into())), + PythonScanSourceInput::Buffer(buff) => (None, ScanSources::Buffers([buff].into())), + }) } #[pymethods] @@ -45,12 +40,12 @@ impl PyLazyFrame { #[cfg(feature = "json")] #[allow(clippy::too_many_arguments)] #[pyo3(signature = ( - path, paths, infer_schema_length, schema, schema_overrides, batch_size, n_rows, low_memory, rechunk, + source, sources, infer_schema_length, schema, schema_overrides, batch_size, n_rows, low_memory, rechunk, row_index, ignore_errors, include_file_paths, cloud_options, retries, file_cache_ttl ))] fn new_from_ndjson( - path: Option, - paths: Vec, + source: Option, + sources: Wrap, infer_schema_length: Option, schema: Option>, schema_overrides: Option>, @@ -66,41 +61,31 @@ impl PyLazyFrame { file_cache_ttl: Option, ) -> PyResult { let row_index = row_index.map(|(name, offset)| RowIndex { - name: Arc::from(name.as_str()), + name: name.into(), offset, }); - #[cfg(feature = "cloud")] - let cloud_options = { - let first_path = if let Some(path) = &path { - path - } else { - paths - .first() - .ok_or_else(|| PyValueError::new_err("expected a path argument"))? - }; + let sources = sources.0; + let (first_path, sources) = match source { + None => (sources.first_path().map(|p| p.to_path_buf()), sources), + Some(source) => pyobject_to_first_path_and_scan_sources(source)?, + }; - let first_path_url = first_path.to_string_lossy(); + let mut r = LazyJsonLineReader::new_with_sources(sources); - let mut cloud_options = if let Some(opts) = cloud_options { - parse_cloud_options(&first_path_url, opts)? - } else { - parse_cloud_options(&first_path_url, vec![])? 
- }; + #[cfg(feature = "cloud")] + if let Some(first_path) = first_path { + let first_path_url = first_path.to_string_lossy(); + let mut cloud_options = + parse_cloud_options(&first_path_url, cloud_options.unwrap_or_default())?; cloud_options = cloud_options.with_max_retries(retries); if let Some(file_cache_ttl) = file_cache_ttl { cloud_options.file_cache_ttl = file_cache_ttl; } - Some(cloud_options) - }; - - let r = if let Some(path) = &path { - LazyJsonLineReader::new(path) - } else { - LazyJsonLineReader::new_paths(paths.into()) + r = r.with_cloud_options(Some(cloud_options)); }; let lf = r @@ -113,8 +98,7 @@ impl PyLazyFrame { .with_schema_overwrite(schema_overrides.map(|x| Arc::new(x.0))) .with_row_index(row_index) .with_ignore_errors(ignore_errors) - .with_include_file_paths(include_file_paths.map(Arc::from)) - .with_cloud_options(cloud_options) + .with_include_file_paths(include_file_paths.map(|x| x.into())) .finish() .map_err(PyPolarsErr::from)?; @@ -123,7 +107,7 @@ impl PyLazyFrame { #[staticmethod] #[cfg(feature = "csv")] - #[pyo3(signature = (path, paths, separator, has_header, ignore_errors, skip_rows, n_rows, cache, overwrite_dtype, + #[pyo3(signature = (source, sources, separator, has_header, ignore_errors, skip_rows, n_rows, cache, overwrite_dtype, low_memory, comment_prefix, quote_char, null_values, missing_utf8_is_empty_string, infer_schema_length, with_schema_modify, rechunk, skip_rows_after_header, encoding, row_index, try_parse_dates, eol_char, raise_if_empty, truncate_ragged_lines, decimal_comma, glob, schema, @@ -131,8 +115,8 @@ impl PyLazyFrame { ) )] fn new_from_csv( - path: Option, - paths: Vec, + source: Option, + sources: Wrap, separator: &str, has_header: bool, ignore_errors: bool, @@ -168,49 +152,37 @@ impl PyLazyFrame { let separator = separator.as_bytes()[0]; let eol_char = eol_char.as_bytes()[0]; let row_index = row_index.map(|(name, offset)| RowIndex { - name: Arc::from(name.as_str()), + name: name.into(), offset, }); let overwrite_dtype = overwrite_dtype.map(|overwrite_dtype| { overwrite_dtype .into_iter() - .map(|(name, dtype)| Field::new(&name, dtype.0)) + .map(|(name, dtype)| Field::new((&*name).into(), dtype.0)) .collect::() }); - #[cfg(feature = "cloud")] - let cloud_options = { - let first_path = if let Some(path) = &path { - path - } else { - paths - .first() - .ok_or_else(|| PyValueError::new_err("expected a path argument"))? - }; + let sources = sources.0; + let (first_path, sources) = match source { + None => (sources.first_path().map(|p| p.to_path_buf()), sources), + Some(source) => pyobject_to_first_path_and_scan_sources(source)?, + }; - let first_path_url = first_path.to_string_lossy(); + let mut r = LazyCsvReader::new_with_sources(sources); - let mut cloud_options = if let Some(opts) = cloud_options { - parse_cloud_options(&first_path_url, opts)? - } else { - parse_cloud_options(&first_path_url, vec![])? 
- }; - - cloud_options = cloud_options.with_max_retries(retries); + #[cfg(feature = "cloud")] + if let Some(first_path) = first_path { + let first_path_url = first_path.to_string_lossy(); + let mut cloud_options = + parse_cloud_options(&first_path_url, cloud_options.unwrap_or_default())?; if let Some(file_cache_ttl) = file_cache_ttl { cloud_options.file_cache_ttl = file_cache_ttl; } - - Some(cloud_options) - }; - - let r = if let Some(path) = path.as_ref() { - LazyCsvReader::new(path) - } else { - LazyCsvReader::new_paths(paths.into()) - }; + cloud_options = cloud_options.with_max_retries(retries); + r = r.with_cloud_options(Some(cloud_options)); + } let mut r = r .with_infer_schema_length(infer_schema_length) @@ -223,7 +195,7 @@ impl PyLazyFrame { .with_dtype_overwrite(overwrite_dtype.map(Arc::new)) .with_schema(schema.map(|schema| Arc::new(schema.0))) .with_low_memory(low_memory) - .with_comment_prefix(comment_prefix) + .with_comment_prefix(comment_prefix.map(|x| x.into())) .with_quote_char(quote_char) .with_eol_char(eol_char) .with_rechunk(rechunk) @@ -237,8 +209,7 @@ impl PyLazyFrame { .with_decimal_comma(decimal_comma) .with_glob(glob) .with_raise_if_empty(raise_if_empty) - .with_cloud_options(cloud_options) - .with_include_file_paths(include_file_paths.map(Arc::from)); + .with_include_file_paths(include_file_paths.map(|x| x.into())); if let Some(lambda) = with_schema_modify { let f = |schema: Schema| { @@ -254,9 +225,9 @@ impl PyLazyFrame { ShapeMismatch: "The length of the new names list should be equal to or less than the original column length", ); Ok(schema - .iter_dtypes() + .iter_values() .zip(new_names) - .map(|(dtype, name)| Field::from_owned(name.into(), dtype.clone())) + .map(|(dtype, name)| Field::new(name.into(), dtype.clone())) .collect()) }) }; @@ -268,12 +239,12 @@ impl PyLazyFrame { #[cfg(feature = "parquet")] #[staticmethod] - #[pyo3(signature = (path, paths, n_rows, cache, parallel, rechunk, row_index, + #[pyo3(signature = (source, sources, n_rows, cache, parallel, rechunk, row_index, low_memory, cloud_options, use_statistics, hive_partitioning, hive_schema, try_parse_hive_dates, retries, glob, include_file_paths) )] fn new_from_parquet( - path: Option, - paths: Vec, + source: Option, + sources: Wrap, n_rows: Option, cache: bool, parallel: Wrap, @@ -292,33 +263,11 @@ impl PyLazyFrame { let parallel = parallel.0; let hive_schema = hive_schema.map(|s| Arc::new(s.0)); - let first_path = if let Some(path) = &path { - path - } else { - paths - .first() - .ok_or_else(|| PyValueError::new_err("expected a path argument"))? - }; - - #[cfg(feature = "cloud")] - let cloud_options = { - let first_path_url = first_path.to_string_lossy(); - - let mut cloud_options = if let Some(opts) = cloud_options { - parse_cloud_options(&first_path_url, opts)? - } else { - parse_cloud_options(&first_path_url, vec![])? 
- }; - - cloud_options = cloud_options.with_max_retries(retries); - - Some(cloud_options) - }; - let row_index = row_index.map(|(name, offset)| RowIndex { - name: Arc::from(name.as_str()), + name: name.into(), offset, }); + let hive_options = HiveOptions { enabled: hive_partitioning, hive_start_idx: 0, @@ -326,40 +275,49 @@ impl PyLazyFrame { try_parse_dates: try_parse_hive_dates, }; - let args = ScanArgsParquet { + let mut args = ScanArgsParquet { n_rows, cache, parallel, rechunk, row_index, low_memory, - cloud_options, + cloud_options: None, use_statistics, hive_options, glob, - include_file_paths: include_file_paths.map(Arc::from), + include_file_paths: include_file_paths.map(|x| x.into()), }; - let lf = if path.is_some() { - LazyFrame::scan_parquet(first_path, args) - } else { - LazyFrame::scan_parquet_files(Arc::from(paths), args) + let sources = sources.0; + let (first_path, sources) = match source { + None => (sources.first_path().map(|p| p.to_path_buf()), sources), + Some(source) => pyobject_to_first_path_and_scan_sources(source)?, + }; + + #[cfg(feature = "cloud")] + if let Some(first_path) = first_path { + let first_path_url = first_path.to_string_lossy(); + let cloud_options = + parse_cloud_options(&first_path_url, cloud_options.unwrap_or_default())?; + args.cloud_options = Some(cloud_options.with_max_retries(retries)); } - .map_err(PyPolarsErr::from)?; + + let lf = LazyFrame::scan_parquet_sources(sources, args).map_err(PyPolarsErr::from)?; + Ok(lf.into()) } #[cfg(feature = "ipc")] #[staticmethod] - #[pyo3(signature = (path, paths, n_rows, cache, rechunk, row_index, memory_map, cloud_options, hive_partitioning, hive_schema, try_parse_hive_dates, retries, file_cache_ttl, include_file_paths))] + #[pyo3(signature = (source, sources, n_rows, cache, rechunk, row_index, cloud_options, hive_partitioning, hive_schema, try_parse_hive_dates, retries, file_cache_ttl, include_file_paths))] fn new_from_ipc( - path: Option, - paths: Vec, + source: Option, + sources: Wrap, n_rows: Option, cache: bool, rechunk: bool, row_index: Option<(String, IdxSize)>, - memory_map: bool, cloud_options: Option>, hive_partitioning: Option, hive_schema: Option>, @@ -369,37 +327,10 @@ impl PyLazyFrame { include_file_paths: Option, ) -> PyResult { let row_index = row_index.map(|(name, offset)| RowIndex { - name: Arc::from(name.as_str()), + name: name.into(), offset, }); - #[cfg(feature = "cloud")] - let cloud_options = { - let first_path = if let Some(path) = &path { - path - } else { - paths - .first() - .ok_or_else(|| PyValueError::new_err("expected a path argument"))? - }; - - let first_path_url = first_path.to_string_lossy(); - - let mut cloud_options = if let Some(opts) = cloud_options { - parse_cloud_options(&first_path_url, opts)? - } else { - parse_cloud_options(&first_path_url, vec![])? 
- }; - - cloud_options = cloud_options.with_max_retries(retries); - - if let Some(file_cache_ttl) = file_cache_ttl { - cloud_options.file_cache_ttl = file_cache_ttl; - } - - Some(cloud_options) - }; - let hive_options = HiveOptions { enabled: hive_partitioning, hive_start_idx: 0, @@ -407,24 +338,36 @@ impl PyLazyFrame { try_parse_dates: try_parse_hive_dates, }; - let args = ScanArgsIpc { + let mut args = ScanArgsIpc { n_rows, cache, rechunk, row_index, - memory_map, #[cfg(feature = "cloud")] - cloud_options, + cloud_options: None, hive_options, - include_file_paths: include_file_paths.map(Arc::from), + include_file_paths: include_file_paths.map(|x| x.into()), }; - let lf = if let Some(path) = &path { - LazyFrame::scan_ipc(path, args) - } else { - LazyFrame::scan_ipc_files(paths.into(), args) + let sources = sources.0; + let (first_path, sources) = match source { + None => (sources.first_path().map(|p| p.to_path_buf()), sources), + Some(source) => pyobject_to_first_path_and_scan_sources(source)?, + }; + + #[cfg(feature = "cloud")] + if let Some(first_path) = first_path { + let first_path_url = first_path.to_string_lossy(); + + let mut cloud_options = + parse_cloud_options(&first_path_url, cloud_options.unwrap_or_default())?; + if let Some(file_cache_ttl) = file_cache_ttl { + cloud_options.file_cache_ttl = file_cache_ttl; + } + args.cloud_options = Some(cloud_options.with_max_retries(retries)); } - .map_err(PyPolarsErr::from)?; + + let lf = LazyFrame::scan_ipc_sources(sources, args).map_err(PyPolarsErr::from)?; Ok(lf.into()) } @@ -444,8 +387,11 @@ impl PyLazyFrame { scan_fn: PyObject, pyarrow: bool, ) -> PyResult { - let schema = - Schema::from_iter(schema.into_iter().map(|(name, dt)| Field::new(&name, dt.0))); + let schema = Schema::from_iter( + schema + .into_iter() + .map(|(name, dt)| Field::new((&*name).into(), dt.0)), + ); Ok(LazyFrame::scan_from_python_function(schema, scan_fn, pyarrow).into()) } @@ -503,10 +449,14 @@ impl PyLazyFrame { .with_simplify_expr(simplify_expression) .with_slice_pushdown(slice_pushdown) .with_cluster_with_columns(cluster_with_columns) - .with_streaming(streaming) ._with_eager(_eager) .with_projection_pushdown(projection_pushdown); + #[cfg(feature = "streaming")] + { + ldf = ldf.with_streaming(streaming); + } + #[cfg(feature = "new_streaming")] { ldf = ldf.with_new_streaming(new_streaming); @@ -908,12 +858,12 @@ impl PyLazyFrame { strategy: Wrap, tolerance: Option>>, tolerance_str: Option, - coalesce: Option, + coalesce: bool, ) -> PyResult { - let coalesce = match coalesce { - None => JoinCoalesce::JoinSpecific, - Some(true) => JoinCoalesce::CoalesceColumns, - Some(false) => JoinCoalesce::KeepColumns, + let coalesce = if coalesce { + JoinCoalesce::CoalesceColumns + } else { + JoinCoalesce::KeepColumns }; let ldf = self.ldf.clone(); let other = other.ldf; @@ -929,8 +879,8 @@ impl PyLazyFrame { .coalesce(coalesce) .how(JoinType::AsOf(AsOfOptions { strategy: strategy.0, - left_by: left_by.map(strings_to_smartstrings), - right_by: right_by.map(strings_to_smartstrings), + left_by: left_by.map(strings_to_pl_smallstr), + right_by: right_by.map(strings_to_pl_smallstr), tolerance: tolerance.map(|t| t.0.into_static().unwrap()), tolerance_str: tolerance_str.map(|s| s.into()), })) @@ -984,6 +934,20 @@ impl PyLazyFrame { .into()) } + fn join_where(&self, other: Self, predicates: Vec, suffix: String) -> PyResult { + let ldf = self.ldf.clone(); + let other = other.ldf; + + let predicates = predicates.to_exprs(); + + Ok(ldf + .join_builder() + .with(other) + .suffix(suffix) 
+ .join_where(predicates) + .into()) + } + fn with_columns(&mut self, exprs: Vec) -> Self { let ldf = self.ldf.clone(); ldf.with_columns(exprs.to_exprs()).into() @@ -1081,21 +1045,22 @@ impl PyLazyFrame { fn unique( &self, maintain_order: bool, - subset: Option>, + subset: Option>, keep: Wrap, ) -> Self { let ldf = self.ldf.clone(); + let subset = subset.map(|e| e.to_exprs()); match maintain_order { - true => ldf.unique_stable(subset, keep.0), - false => ldf.unique(subset, keep.0), + true => ldf.unique_stable_generic(subset, keep.0), + false => ldf.unique_generic(subset, keep.0), } .into() } - fn drop_nulls(&self, subset: Option>) -> Self { + fn drop_nulls(&self, subset: Option>) -> Self { let ldf = self.ldf.clone(); - ldf.drop_nulls(subset.map(|v| v.into_iter().map(|s| col(&s)).collect())) - .into() + let subset = subset.map(|e| e.to_exprs()); + ldf.drop_nulls(subset).into() } fn slice(&self, offset: i64, len: Option) -> Self { @@ -1108,21 +1073,19 @@ impl PyLazyFrame { ldf.tail(n).into() } - #[pyo3(signature = (on, index, value_name, variable_name, streamable))] + #[pyo3(signature = (on, index, value_name, variable_name))] fn unpivot( &self, - on: Vec, - index: Vec, + on: Vec, + index: Vec, value_name: Option, variable_name: Option, - streamable: bool, ) -> Self { - let args = UnpivotArgs { - on: strings_to_smartstrings(on), - index: strings_to_smartstrings(index), + let args = UnpivotArgsDSL { + on: on.into_iter().map(|e| e.inner.into()).collect(), + index: index.into_iter().map(|e| e.inner.into()).collect(), value_name: value_name.map(|s| s.into()), variable_name: variable_name.map(|s| s.into()), - streamable, }; let ldf = self.ldf.clone(); @@ -1145,11 +1108,11 @@ impl PyLazyFrame { schema: Option>, validate_output: bool, ) -> Self { - let mut opt = OptState::default(); - opt.set(OptState::PREDICATE_PUSHDOWN, predicate_pushdown); - opt.set(OptState::PROJECTION_PUSHDOWN, projection_pushdown); - opt.set(OptState::SLICE_PUSHDOWN, slice_pushdown); - opt.set(OptState::STREAMING, streamable); + let mut opt = OptFlags::default(); + opt.set(OptFlags::PREDICATE_PUSHDOWN, predicate_pushdown); + opt.set(OptFlags::PROJECTION_PUSHDOWN, projection_pushdown); + opt.set(OptFlags::SLICE_PUSHDOWN, slice_pushdown); + opt.set(OptFlags::STREAMING, streamable); self.ldf .clone() @@ -1189,19 +1152,20 @@ impl PyLazyFrame { fn collect_schema(&mut self, py: Python) -> PyResult { let schema = py - .allow_threads(|| self.ldf.schema()) + .allow_threads(|| self.ldf.collect_schema()) .map_err(PyPolarsErr::from)?; let schema_dict = PyDict::new_bound(py); schema.iter_fields().for_each(|fld| { schema_dict - .set_item(fld.name().as_str(), Wrap(fld.data_type().clone())) + .set_item(fld.name().as_str(), Wrap(fld.dtype().clone())) .unwrap() }); Ok(schema_dict.to_object(py)) } - fn unnest(&self, columns: Vec) -> Self { + fn unnest(&self, columns: Vec) -> Self { + let columns = columns.to_exprs(); self.ldf.clone().unnest(columns).into() } diff --git a/crates/polars-python/src/lazyframe/mod.rs b/crates/polars-python/src/lazyframe/mod.rs new file mode 100644 index 000000000000..85433332e415 --- /dev/null +++ b/crates/polars-python/src/lazyframe/mod.rs @@ -0,0 +1,24 @@ +mod exitable; +#[cfg(feature = "pymethods")] +mod general; +#[cfg(feature = "pymethods")] +mod serde; +pub mod visit; +pub mod visitor; + +pub use exitable::PyInProcessQuery; +use polars::prelude::LazyFrame; +use pyo3::pyclass; + +#[pyclass] +#[repr(transparent)] +#[derive(Clone)] +pub struct PyLazyFrame { + pub ldf: LazyFrame, +} + +impl From for PyLazyFrame { 
+ fn from(ldf: LazyFrame) -> Self { + PyLazyFrame { ldf } + } +} diff --git a/py-polars/src/lazyframe/serde.rs b/crates/polars-python/src/lazyframe/serde.rs similarity index 100% rename from py-polars/src/lazyframe/serde.rs rename to crates/polars-python/src/lazyframe/serde.rs diff --git a/py-polars/src/lazyframe/visit.rs b/crates/polars-python/src/lazyframe/visit.rs similarity index 92% rename from py-polars/src/lazyframe/visit.rs rename to crates/polars-python/src/lazyframe/visit.rs index 874d646cba61..4a24261363f3 100644 --- a/py-polars/src/lazyframe/visit.rs +++ b/crates/polars-python/src/lazyframe/visit.rs @@ -1,18 +1,21 @@ -use std::sync::Mutex; +use std::sync::{Arc, Mutex}; +use polars::prelude::PolarsError; use polars_plan::plans::{to_aexpr, Context, IR}; use polars_plan::prelude::expr_ir::ExprIR; use polars_plan::prelude::{AExpr, PythonOptions, PythonScanSource}; use polars_utils::arena::{Arena, Node}; use pyo3::prelude::*; -use visitor::{expr_nodes, nodes}; +use pyo3::types::PyList; -use super::*; -use crate::raise_err; +use super::visitor::{expr_nodes, nodes}; +use super::PyLazyFrame; +use crate::error::PyPolarsErr; +use crate::{raise_err, PyExpr, Wrap}; #[derive(Clone)] #[pyclass] -pub(crate) struct PyExprIR { +pub struct PyExprIR { #[pyo3(get)] node: usize, #[pyo3(get)] @@ -23,7 +26,7 @@ impl From for PyExprIR { fn from(value: ExprIR) -> Self { Self { node: value.node().0, - output_name: value.output_name().into(), + output_name: value.output_name().to_string(), } } } @@ -32,7 +35,7 @@ impl From<&ExprIR> for PyExprIR { fn from(value: &ExprIR) -> Self { Self { node: value.node().0, - output_name: value.output_name().into(), + output_name: value.output_name().to_string(), } } } @@ -54,7 +57,7 @@ impl NodeTraverser { // Increment major on breaking changes to the IR (e.g. renaming // fields, reordering tuples), minor on backwards compatible // changes (e.g. exposing a new expression node). 
- const VERSION: Version = (1, 0); + const VERSION: Version = (2, 0); pub(crate) fn new(root: Node, lp_arena: Arena, expr_arena: Arena) -> Self { Self { @@ -195,8 +198,12 @@ impl NodeTraverser { Ok(( expressions .into_iter() - .map(|e| to_aexpr(e.inner, &mut expr_arena).0) - .collect(), + .map(|e| { + to_aexpr(e.inner, &mut expr_arena) + .map_err(PyPolarsErr::from) + .map(|v| v.0) + }) + .collect::>()?, expr_arena.len(), )) } diff --git a/py-polars/src/lazyframe/visitor/expr_nodes.rs b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs similarity index 87% rename from py-polars/src/lazyframe/visitor/expr_nodes.rs rename to crates/polars-python/src/lazyframe/visitor/expr_nodes.rs index 813e20fcb5ae..ce95204056bd 100644 --- a/py-polars/src/lazyframe/visitor/expr_nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs @@ -1,10 +1,13 @@ use polars::datatypes::TimeUnit; +use polars::series::ops::NullBehavior; use polars_core::prelude::{NonExistent, QuantileInterpolOptions}; use polars_core::series::IsSorted; use polars_ops::prelude::ClosedInterval; +use polars_ops::series::InterpolationMethod; +#[cfg(feature = "search_sorted")] +use polars_ops::series::SearchSortedSide; use polars_plan::dsl::function_expr::rolling::RollingFunction; use polars_plan::dsl::function_expr::rolling_by::RollingFunctionBy; -use polars_plan::dsl::function_expr::trigonometry::TrigonometricFunction; use polars_plan::dsl::{BooleanFunction, StringFunction, TemporalFunction}; use polars_plan::prelude::{ AExpr, FunctionExpr, GroupbyOptions, IRAggExpr, LiteralValue, Operator, PowFunction, @@ -585,11 +588,11 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { .into_py(py), AExpr::Cast { expr, - data_type, + dtype, options, } => Cast { expr: expr.0, - dtype: Wrap(data_type.clone()).to_object(py), + dtype: Wrap(dtype.clone()).to_object(py), options: *options as u8, } .into_py(py), @@ -760,7 +763,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { ignore_nulls, } => ( PyStringFunction::ConcatHorizontal.into_py(py), - delimiter, + delimiter.as_str(), ignore_nulls, ) .to_object(py), @@ -769,10 +772,11 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { ignore_nulls, } => ( PyStringFunction::ConcatVertical.into_py(py), - delimiter, + delimiter.as_str(), ignore_nulls, ) .to_object(py), + #[cfg(feature = "regex")] StringFunction::Contains { literal, strict } => { (PyStringFunction::Contains.into_py(py), literal, strict).to_object(py) }, @@ -792,9 +796,10 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { StringFunction::ExtractGroups { dtype, pat } => ( PyStringFunction::ExtractGroups.into_py(py), Wrap(dtype.clone()).to_object(py), - pat, + pat.as_str(), ) .to_object(py), + #[cfg(feature = "regex")] StringFunction::Find { literal, strict } => { (PyStringFunction::Find.into_py(py), literal, strict).to_object(py) }, @@ -819,6 +824,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { StringFunction::JsonPathMatch => { (PyStringFunction::JsonPathMatch.into_py(py),).to_object(py) }, + #[cfg(feature = "regex")] StringFunction::Replace { n, literal } => { (PyStringFunction::Replace.into_py(py), n, literal).to_object(py) }, @@ -837,12 +843,14 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { StringFunction::HexEncode => { (PyStringFunction::HexEncode.into_py(py),).to_object(py) }, + #[cfg(feature = "binary_encoding")] StringFunction::HexDecode(strict) => { (PyStringFunction::HexDecode.into_py(py), 
strict).to_object(py) }, StringFunction::Base64Encode => { (PyStringFunction::Base64Encode.into_py(py),).to_object(py) }, + #[cfg(feature = "binary_encoding")] StringFunction::Base64Decode(strict) => { (PyStringFunction::Base64Decode.into_py(py), strict).to_object(py) }, @@ -887,6 +895,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { StringFunction::ToDecimal(inference_length) => { (PyStringFunction::ToDecimal.into_py(py), inference_length).to_object(py) }, + #[cfg(feature = "nightly")] StringFunction::Titlecase => { (PyStringFunction::Titlecase.into_py(py),).to_object(py) }, @@ -968,8 +977,9 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { TemporalFunction::WithTimeUnit(time_unit) => { (PyTemporalFunction::WithTimeUnit, Wrap(*time_unit)).into_py(py) }, + #[cfg(feature = "timezones")] TemporalFunction::ConvertTimeZone(time_zone) => { - (PyTemporalFunction::ConvertTimeZone, time_zone).into_py(py) + (PyTemporalFunction::ConvertTimeZone, time_zone.as_str()).into_py(py) }, TemporalFunction::TimeStamp(time_unit) => { (PyTemporalFunction::TimeStamp, Wrap(*time_unit)).into_py(py) @@ -978,11 +988,14 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { TemporalFunction::OffsetBy => (PyTemporalFunction::OffsetBy,).into_py(py), TemporalFunction::MonthStart => (PyTemporalFunction::MonthStart,).into_py(py), TemporalFunction::MonthEnd => (PyTemporalFunction::MonthEnd,).into_py(py), + #[cfg(feature = "timezones")] TemporalFunction::BaseUtcOffset => { (PyTemporalFunction::BaseUtcOffset,).into_py(py) }, + #[cfg(feature = "timezones")] TemporalFunction::DSTOffset => (PyTemporalFunction::DSTOffset,).into_py(py), TemporalFunction::Round => (PyTemporalFunction::Round,).into_py(py), + #[cfg(feature = "timezones")] TemporalFunction::ReplaceTimeZone(time_zone, non_existent) => ( PyTemporalFunction::ReplaceTimeZone, time_zone @@ -1033,6 +1046,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { BooleanFunction::IsBetween { closed } => { (PyBooleanFunction::IsBetween, Wrap(*closed)).into_py(py) }, + #[cfg(feature = "is_in")] BooleanFunction::IsIn => (PyBooleanFunction::IsIn,).into_py(py), BooleanFunction::AllHorizontal => { (PyBooleanFunction::AllHorizontal,).into_py(py) @@ -1044,41 +1058,58 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { }, FunctionExpr::Abs => ("abs",).to_object(py), #[cfg(feature = "hist")] - FunctionExpr::Hist { .. 
} => return Err(PyNotImplementedError::new_err("hist")), + FunctionExpr::Hist { + bin_count, + include_category, + include_breakpoint, + } => ("hist", bin_count, include_category, include_breakpoint).to_object(py), FunctionExpr::NullCount => ("null_count",).to_object(py), FunctionExpr::Pow(f) => match f { PowFunction::Generic => ("pow",).to_object(py), PowFunction::Sqrt => ("sqrt",).to_object(py), PowFunction::Cbrt => ("cbrt",).to_object(py), }, - FunctionExpr::Hash(_, _, _, _) => { - return Err(PyNotImplementedError::new_err("hash")) + FunctionExpr::Hash(seed, seed_1, seed_2, seed_3) => { + ("hash", seed, seed_1, seed_2, seed_3).to_object(py) }, FunctionExpr::ArgWhere => ("argwhere",).to_object(py), #[cfg(feature = "search_sorted")] - FunctionExpr::SearchSorted(_) => { - return Err(PyNotImplementedError::new_err("search sorted")) - }, + FunctionExpr::SearchSorted(side) => ( + "search_sorted", + match side { + SearchSortedSide::Any => "any", + SearchSortedSide::Left => "left", + SearchSortedSide::Right => "right", + }, + ) + .to_object(py), FunctionExpr::Range(_) => return Err(PyNotImplementedError::new_err("range")), - FunctionExpr::Trigonometry(trigfun) => match trigfun { - TrigonometricFunction::Cos => ("cos",), - TrigonometricFunction::Cot => ("cot",), - TrigonometricFunction::Sin => ("sin",), - TrigonometricFunction::Tan => ("tan",), - TrigonometricFunction::ArcCos => ("arccos",), - TrigonometricFunction::ArcSin => ("arcsin",), - TrigonometricFunction::ArcTan => ("arctan",), - TrigonometricFunction::Cosh => ("cosh",), - TrigonometricFunction::Sinh => ("sinh",), - TrigonometricFunction::Tanh => ("tanh",), - TrigonometricFunction::ArcCosh => ("arccosh",), - TrigonometricFunction::ArcSinh => ("arcsinh",), - TrigonometricFunction::ArcTanh => ("arctanh",), - TrigonometricFunction::Degrees => ("degrees",), - TrigonometricFunction::Radians => ("radians",), - } - .to_object(py), + #[cfg(feature = "trigonometry")] + FunctionExpr::Trigonometry(trigfun) => { + use polars_plan::dsl::function_expr::trigonometry::TrigonometricFunction; + + match trigfun { + TrigonometricFunction::Cos => ("cos",), + TrigonometricFunction::Cot => ("cot",), + TrigonometricFunction::Sin => ("sin",), + TrigonometricFunction::Tan => ("tan",), + TrigonometricFunction::ArcCos => ("arccos",), + TrigonometricFunction::ArcSin => ("arcsin",), + TrigonometricFunction::ArcTan => ("arctan",), + TrigonometricFunction::Cosh => ("cosh",), + TrigonometricFunction::Sinh => ("sinh",), + TrigonometricFunction::Tanh => ("tanh",), + TrigonometricFunction::ArcCosh => ("arccosh",), + TrigonometricFunction::ArcSinh => ("arcsinh",), + TrigonometricFunction::ArcTanh => ("arctanh",), + TrigonometricFunction::Degrees => ("degrees",), + TrigonometricFunction::Radians => ("radians",), + } + .to_object(py) + }, + #[cfg(feature = "trigonometry")] FunctionExpr::Atan2 => ("atan2",).to_object(py), + #[cfg(feature = "sign")] FunctionExpr::Sign => ("sign",).to_object(py), FunctionExpr::FillNull => ("fill_null",).to_object(py), FunctionExpr::RollingExpr(rolling) => match rolling { @@ -1130,17 +1161,13 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { return Err(PyNotImplementedError::new_err("rolling std by")) }, }, - FunctionExpr::ShiftAndFill => { - return Err(PyNotImplementedError::new_err("shift and fill")) - }, + FunctionExpr::ShiftAndFill => ("shift_and_fill",).to_object(py), FunctionExpr::Shift => ("shift",).to_object(py), FunctionExpr::DropNans => ("drop_nans",).to_object(py), FunctionExpr::DropNulls => 
("drop_nulls",).to_object(py), FunctionExpr::Mode => ("mode",).to_object(py), - FunctionExpr::Skew(_) => return Err(PyNotImplementedError::new_err("skew")), - FunctionExpr::Kurtosis(_, _) => { - return Err(PyNotImplementedError::new_err("kurtosis")) - }, + FunctionExpr::Skew(bias) => ("skew", bias).to_object(py), + FunctionExpr::Kurtosis(fisher, bias) => ("kurtosis", fisher, bias).to_object(py), FunctionExpr::Reshape(_, _) => { return Err(PyNotImplementedError::new_err("reshape")) }, @@ -1151,11 +1178,8 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { options: _, seed: _, } => return Err(PyNotImplementedError::new_err("rank")), - FunctionExpr::Clip { - has_min: _, - has_max: _, - } => return Err(PyNotImplementedError::new_err("clip")), - FunctionExpr::AsStruct => return Err(PyNotImplementedError::new_err("as struct")), + FunctionExpr::Clip { has_min, has_max } => ("clip", has_min, has_max).to_object(py), + FunctionExpr::AsStruct => ("as_struct",).to_object(py), #[cfg(feature = "top_k")] FunctionExpr::TopK { descending } => ("top_k", descending).to_object(py), FunctionExpr::CumCount { reverse } => ("cum_count", reverse).to_object(py), @@ -1165,37 +1189,41 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { FunctionExpr::CumMax { reverse } => ("cum_max", reverse).to_object(py), FunctionExpr::Reverse => ("reverse",).to_object(py), FunctionExpr::ValueCounts { - sort: _, - parallel: _, - name: _, - normalize: _, - } => return Err(PyNotImplementedError::new_err("value counts")), + sort, + parallel, + name, + normalize, + } => ("value_counts", sort, parallel, name.as_str(), normalize).to_object(py), FunctionExpr::UniqueCounts => ("unique_counts",).to_object(py), - FunctionExpr::ApproxNUnique => { - return Err(PyNotImplementedError::new_err("approx nunique")) - }, + FunctionExpr::ApproxNUnique => ("approx_n_unique",).to_object(py), FunctionExpr::Coalesce => ("coalesce",).to_object(py), - FunctionExpr::ShrinkType => { - return Err(PyNotImplementedError::new_err("shrink type")) - }, - FunctionExpr::Diff(_, _) => return Err(PyNotImplementedError::new_err("diff")), + FunctionExpr::ShrinkType => ("shrink_dtype",).to_object(py), + FunctionExpr::Diff(n, null_behaviour) => ( + "diff", + n, + match null_behaviour { + NullBehavior::Drop => "drop", + NullBehavior::Ignore => "ignore", + }, + ) + .to_object(py), #[cfg(feature = "pct_change")] - FunctionExpr::PctChange => { - return Err(PyNotImplementedError::new_err("pct change")) - }, - FunctionExpr::Interpolate(_) => { - return Err(PyNotImplementedError::new_err("interpolate")) - }, - FunctionExpr::InterpolateBy => { - return Err(PyNotImplementedError::new_err("interpolate_by")) + FunctionExpr::PctChange => ("pct_change",).to_object(py), + FunctionExpr::Interpolate(method) => ( + "interpolate", + match method { + InterpolationMethod::Linear => "linear", + InterpolationMethod::Nearest => "nearest", + }, + ) + .to_object(py), + FunctionExpr::InterpolateBy => ("interpolate_by",).to_object(py), + FunctionExpr::Entropy { base, normalize } => { + ("entropy", base, normalize).to_object(py) }, - FunctionExpr::Entropy { - base: _, - normalize: _, - } => return Err(PyNotImplementedError::new_err("entropy")), - FunctionExpr::Log { base: _ } => return Err(PyNotImplementedError::new_err("log")), - FunctionExpr::Log1p => return Err(PyNotImplementedError::new_err("log1p")), - FunctionExpr::Exp => return Err(PyNotImplementedError::new_err("exp")), + FunctionExpr::Log { base } => ("log", base).to_object(py), + FunctionExpr::Log1p => 
("log1p",).to_object(py), + FunctionExpr::Exp => ("exp",).to_object(py), FunctionExpr::Unique(maintain_order) => ("unique", maintain_order).to_object(py), FunctionExpr::Round { decimals } => ("round", decimals).to_object(py), FunctionExpr::RoundSF { digits } => ("round_sig_figs", digits).to_object(py), @@ -1211,20 +1239,18 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { return Err(PyNotImplementedError::new_err("corr")) }, #[cfg(feature = "peaks")] - FunctionExpr::PeakMin => return Err(PyNotImplementedError::new_err("peak min")), + FunctionExpr::PeakMin => ("peak_max",).to_object(py), #[cfg(feature = "peaks")] - FunctionExpr::PeakMax => return Err(PyNotImplementedError::new_err("peak max")), + FunctionExpr::PeakMax => ("peak_min",).to_object(py), #[cfg(feature = "cutqcut")] FunctionExpr::Cut { .. } => return Err(PyNotImplementedError::new_err("cut")), #[cfg(feature = "cutqcut")] FunctionExpr::QCut { .. } => return Err(PyNotImplementedError::new_err("qcut")), #[cfg(feature = "rle")] - FunctionExpr::RLE => return Err(PyNotImplementedError::new_err("rle")), + FunctionExpr::RLE => ("rle",).to_object(py), #[cfg(feature = "rle")] - FunctionExpr::RLEID => return Err(PyNotImplementedError::new_err("rleid")), - FunctionExpr::ToPhysical => { - return Err(PyNotImplementedError::new_err("to physical")) - }, + FunctionExpr::RLEID => ("rle_id",).to_object(py), + FunctionExpr::ToPhysical => ("to_physical",).to_object(py), FunctionExpr::Random { .. } => { return Err(PyNotImplementedError::new_err("random")) }, @@ -1241,24 +1267,12 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { FunctionExpr::FfiPlugin { .. } => { return Err(PyNotImplementedError::new_err("ffi plugin")) }, - FunctionExpr::BackwardFill { limit: _ } => { - return Err(PyNotImplementedError::new_err("backward fill")) - }, - FunctionExpr::ForwardFill { limit: _ } => { - return Err(PyNotImplementedError::new_err("forward fill")) - }, - FunctionExpr::SumHorizontal => { - return Err(PyNotImplementedError::new_err("sum horizontal")) - }, - FunctionExpr::MaxHorizontal => { - return Err(PyNotImplementedError::new_err("max horizontal")) - }, - FunctionExpr::MeanHorizontal => { - return Err(PyNotImplementedError::new_err("mean horizontal")) - }, - FunctionExpr::MinHorizontal => { - return Err(PyNotImplementedError::new_err("min horizontal")) - }, + FunctionExpr::BackwardFill { limit } => ("backward_fill", limit).to_object(py), + FunctionExpr::ForwardFill { limit } => ("forward_fill", limit).to_object(py), + FunctionExpr::SumHorizontal => ("sum_horizontal",).to_object(py), + FunctionExpr::MaxHorizontal => ("max_horizontal",).to_object(py), + FunctionExpr::MeanHorizontal => ("mean_horizontal",).to_object(py), + FunctionExpr::MinHorizontal => ("min_horizontal",).to_object(py), FunctionExpr::EwmMean { options: _ } => { return Err(PyNotImplementedError::new_err("ewm mean")) }, @@ -1268,23 +1282,20 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { FunctionExpr::EwmVar { options: _ } => { return Err(PyNotImplementedError::new_err("ewm var")) }, - FunctionExpr::Replace => return Err(PyNotImplementedError::new_err("replace")), + FunctionExpr::Replace => ("replace",).to_object(py), FunctionExpr::ReplaceStrict { return_dtype: _ } => { - return Err(PyNotImplementedError::new_err("replace_strict")) + // Can ignore the return dtype because it is encoded in the schema. 
+ ("replace_strict",).to_object(py) }, - FunctionExpr::Negate => return Err(PyNotImplementedError::new_err("negate")), + FunctionExpr::Negate => ("negate",).to_object(py), FunctionExpr::FillNullWithStrategy(_) => { return Err(PyNotImplementedError::new_err("fill null with strategy")) }, FunctionExpr::GatherEvery { n, offset } => { ("gather_every", offset, n).to_object(py) }, - FunctionExpr::Reinterpret(_) => { - return Err(PyNotImplementedError::new_err("reinterpret")) - }, - FunctionExpr::ExtendConstant => { - return Err(PyNotImplementedError::new_err("extend constant")) - }, + FunctionExpr::Reinterpret(signed) => ("reinterpret", signed).to_object(py), + FunctionExpr::ExtendConstant => ("extend_constant",).to_object(py), FunctionExpr::Business(_) => { return Err(PyNotImplementedError::new_err("business")) }, @@ -1330,7 +1341,6 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { } .into_py(py) }, - AExpr::Wildcard => return Err(PyNotImplementedError::new_err("wildcard")), AExpr::Slice { input, offset, @@ -1341,7 +1351,6 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { length: length.0, } .into_py(py), - AExpr::Nth(_) => return Err(PyNotImplementedError::new_err("nth")), AExpr::Len => Len {}.into_py(py), }; Ok(result) diff --git a/crates/polars-python/src/lazyframe/visitor/mod.rs b/crates/polars-python/src/lazyframe/visitor/mod.rs new file mode 100644 index 000000000000..39af9e064b73 --- /dev/null +++ b/crates/polars-python/src/lazyframe/visitor/mod.rs @@ -0,0 +1,2 @@ +pub mod expr_nodes; +pub mod nodes; diff --git a/py-polars/src/lazyframe/visitor/nodes.rs b/crates/polars-python/src/lazyframe/visitor/nodes.rs similarity index 88% rename from py-polars/src/lazyframe/visitor/nodes.rs rename to crates/polars-python/src/lazyframe/visitor/nodes.rs index e08b3bfb37a1..4e9344a61d15 100644 --- a/py-polars/src/lazyframe/visitor/nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/nodes.rs @@ -2,13 +2,13 @@ use polars_core::prelude::{IdxSize, UniqueKeepStrategy}; use polars_ops::prelude::JoinType; use polars_plan::plans::IR; use polars_plan::prelude::{ - FileCount, FileScan, FileScanOptions, FunctionNode, PythonPredicate, PythonScanSource, + FileCount, FileScan, FileScanOptions, FunctionIR, PythonPredicate, PythonScanSource, }; use pyo3::exceptions::{PyNotImplementedError, PyValueError}; use pyo3::prelude::*; -use super::super::visit::PyExprIR; use super::expr_nodes::PyGroupbyOptions; +use crate::lazyframe::visit::PyExprIR; use crate::PyDataFrame; #[pyclass] @@ -55,11 +55,15 @@ impl PyFileOptions { } #[getter] fn with_columns(&self, py: Python<'_>) -> PyResult { - Ok(self - .inner - .with_columns - .as_ref() - .map_or_else(|| py.None(), |cols| cols.to_object(py))) + Ok(self.inner.with_columns.as_ref().map_or_else( + || py.None(), + |cols| { + cols.iter() + .map(|x| x.as_str()) + .collect::>() + .to_object(py) + }, + )) } #[getter] fn cache(&self, _py: Python<'_>) -> PyResult { @@ -71,7 +75,7 @@ impl PyFileOptions { .inner .row_index .as_ref() - .map_or_else(|| py.None(), |n| (n.name.as_ref(), n.offset).to_object(py))) + .map_or_else(|| py.None(), |n| (n.name.as_str(), n.offset).to_object(py))) } #[getter] fn rechunk(&self, _py: Python<'_>) -> PyResult { @@ -270,10 +274,15 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { .scan_fn .as_ref() .map_or_else(|| py.None(), |s| s.0.clone()), - options - .with_columns - .as_ref() - .map_or_else(|| py.None(), |cols| cols.to_object(py)), + options.with_columns.as_ref().map_or_else( + || py.None(), + 
|cols| { + cols.iter() + .map(|x| x.as_str()) + .collect::>() + .to_object(py) + }, + ), python_src, match &options.predicate { PythonPredicate::None => py.None(), @@ -308,7 +317,7 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { )) }, IR::Scan { - paths, + sources, file_info: _, hive_parts: _, predicate, @@ -316,7 +325,10 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { scan_type, file_options, } => Scan { - paths: paths.to_object(py), + paths: sources + .into_paths() + .ok_or_else(|| PyNotImplementedError::new_err("scan with BytesIO"))? + .to_object(py), // TODO: file info file_info: py.None(), predicate: predicate.as_ref().map(|e| e.into()), @@ -469,10 +481,11 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { JoinType::Cross => "cross", JoinType::Semi => "leftsemi", JoinType::Anti => "leftanti", + JoinType::IEJoin(_) => return Err(PyNotImplementedError::new_err("IEJoin")), }, options.args.join_nulls, options.args.slice, - options.args.suffix.clone(), + options.args.suffix.as_deref(), options.args.coalesce.coalesce(&options.args.how), ) .to_object(py), @@ -507,10 +520,15 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { UniqueKeepStrategy::None => "none", UniqueKeepStrategy::Any => "any", }, - options - .subset - .as_ref() - .map_or_else(|| py.None(), |f| f.to_object(py)), + options.subset.as_ref().map_or_else( + || py.None(), + |f| { + f.iter() + .map(|s| s.as_ref()) + .collect::>() + .to_object(py) + }, + ), options.maintain_order, options.slice, ) @@ -520,15 +538,10 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { IR::MapFunction { input, function } => MapFunction { input: input.0, function: match function { - FunctionNode::OpaquePython { - function: _, - schema: _, - predicate_pd: _, - projection_pd: _, - streamable: _, - validate_output: _, - } => return Err(PyNotImplementedError::new_err("opaque python mapfunction")), - FunctionNode::Opaque { + FunctionIR::OpaquePython(_) => { + return Err(PyNotImplementedError::new_err("opaque python mapfunction")) + }, + FunctionIR::Opaque { function: _, schema: _, predicate_pd: _, @@ -536,22 +549,22 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { streamable: _, fmt_str: _, } => return Err(PyNotImplementedError::new_err("opaque rust mapfunction")), - FunctionNode::Pipeline { + FunctionIR::Pipeline { function: _, schema: _, original: _, } => return Err(PyNotImplementedError::new_err("pipeline mapfunction")), - FunctionNode::Unnest { columns } => ( + FunctionIR::Unnest { columns } => ( "unnest", columns.iter().map(|s| s.to_string()).collect::>(), ) .to_object(py), - FunctionNode::Rechunk => ("rechunk",).to_object(py), + FunctionIR::Rechunk => ("rechunk",).to_object(py), #[cfg(feature = "merge_sorted")] - FunctionNode::MergeSorted { column } => { + FunctionIR::MergeSorted { column } => { ("merge_sorted", column.to_string()).to_object(py) }, - FunctionNode::Rename { + FunctionIR::Rename { existing, new, swapping, @@ -563,12 +576,12 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { *swapping, ) .to_object(py), - FunctionNode::Explode { columns, schema: _ } => ( + FunctionIR::Explode { columns, schema: _ } => ( "explode", columns.iter().map(|s| s.to_string()).collect::>(), ) .to_object(py), - FunctionNode::Unpivot { args, schema: _ } => ( + FunctionIR::Unpivot { args, schema: _ } => ( "unpivot", args.index.iter().map(|s| s.as_str()).collect::>(), args.on.iter().map(|s| s.as_str()).collect::>(), @@ -580,13 +593,13 @@ pub(crate) fn 
into_py(py: Python<'_>, plan: &IR) -> PyResult { .map_or_else(|| py.None(), |s| s.as_str().to_object(py)), ) .to_object(py), - FunctionNode::RowIndex { + FunctionIR::RowIndex { name, schema: _, offset, } => ("row_index", name.to_string(), offset.unwrap_or(0)).to_object(py), - FunctionNode::Count { - paths: _, + FunctionIR::FastCount { + sources: _, scan_type: _, alias: _, } => return Err(PyNotImplementedError::new_err("function count")), diff --git a/py-polars/src/lazygroupby.rs b/crates/polars-python/src/lazygroupby.rs similarity index 98% rename from py-polars/src/lazygroupby.rs rename to crates/polars-python/src/lazygroupby.rs index 255bb34917f9..52df635efb53 100644 --- a/py-polars/src/lazygroupby.rs +++ b/crates/polars-python/src/lazygroupby.rs @@ -43,7 +43,7 @@ impl PyLazyGroupBy { let schema = match schema { Some(schema) => Arc::new(schema.0), None => LazyFrame::from(lgb.logical_plan.clone()) - .schema() + .collect_schema() .map_err(PyPolarsErr::from)?, }; diff --git a/crates/polars-python/src/lib.rs b/crates/polars-python/src/lib.rs new file mode 100644 index 000000000000..d696823cb527 --- /dev/null +++ b/crates/polars-python/src/lib.rs @@ -0,0 +1,42 @@ +#![allow(clippy::nonstandard_macro_braces)] // Needed because clippy does not understand proc macro of PyO3 +#![allow(clippy::transmute_undefined_repr)] +#![allow(non_local_definitions)] +#![allow(clippy::too_many_arguments)] // Python functions can have many arguments due to default arguments +#![allow(clippy::disallowed_types)] + +#[cfg(feature = "csv")] +pub mod batched_csv; +#[cfg(feature = "polars_cloud")] +pub mod cloud; +pub mod conversion; +pub mod dataframe; +pub mod datatypes; +pub mod error; +pub mod exceptions; +pub mod expr; +pub mod file; +#[cfg(feature = "pymethods")] +pub mod functions; +pub mod gil_once_cell; +pub mod interop; +pub mod lazyframe; +pub mod lazygroupby; +pub mod map; + +#[cfg(feature = "object")] +pub mod object; +#[cfg(feature = "object")] +pub mod on_startup; +pub mod prelude; +pub mod py_modules; +pub mod series; +#[cfg(feature = "sql")] +pub mod sql; +pub mod utils; + +use crate::conversion::Wrap; +use crate::dataframe::PyDataFrame; +use crate::expr::PyExpr; +use crate::lazyframe::PyLazyFrame; +use crate::lazygroupby::PyLazyGroupBy; +use crate::series::PySeries; diff --git a/py-polars/src/map/dataframe.rs b/crates/polars-python/src/map/dataframe.rs similarity index 88% rename from py-polars/src/map/dataframe.rs rename to crates/polars-python/src/map/dataframe.rs index d50adb7404e1..5be2216b0898 100644 --- a/py-polars/src/map/dataframe.rs +++ b/crates/polars-python/src/map/dataframe.rs @@ -8,15 +8,14 @@ use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PyString, PyTuple}; use super::*; use crate::PyDataFrame; +/// Create iterators for all the Series in the DataFrame. fn get_iters(df: &DataFrame) -> Vec { df.get_columns().iter().map(|s| s.iter()).collect() } -fn get_iters_skip(df: &DataFrame, skip: usize) -> Vec> { - df.get_columns() - .iter() - .map(|s| s.iter().skip(skip)) - .collect() +/// Create iterators for all the Series in the DataFrame, skipping the first `n` rows. 
+fn get_iters_skip(df: &DataFrame, n: usize) -> Vec> { + df.get_columns().iter().map(|s| s.iter().skip(n)).collect() } // the return type is Union[PySeries, PyDataFrame] and a boolean indicating if it is a dataframe or not @@ -168,10 +167,16 @@ where { let skip = usize::from(first_value.is_some()); if init_null_count == df.height() { - ChunkedArray::full_null("map", df.height()) + ChunkedArray::full_null(PlSmallStr::from_static("map"), df.height()) } else { let iter = apply_iter(df, py, lambda, init_null_count, skip); - iterator_to_primitive(iter, init_null_count, first_value, "map", df.height()) + iterator_to_primitive( + iter, + init_null_count, + first_value, + PlSmallStr::from_static("map"), + df.height(), + ) } } @@ -185,10 +190,16 @@ pub fn apply_lambda_with_bool_out_type<'a>( ) -> ChunkedArray { let skip = usize::from(first_value.is_some()); if init_null_count == df.height() { - ChunkedArray::full_null("map", df.height()) + ChunkedArray::full_null(PlSmallStr::from_static("map"), df.height()) } else { let iter = apply_iter(df, py, lambda, init_null_count, skip); - iterator_to_bool(iter, init_null_count, first_value, "map", df.height()) + iterator_to_bool( + iter, + init_null_count, + first_value, + PlSmallStr::from_static("map"), + df.height(), + ) } } @@ -202,10 +213,16 @@ pub fn apply_lambda_with_string_out_type<'a>( ) -> StringChunked { let skip = usize::from(first_value.is_some()); if init_null_count == df.height() { - ChunkedArray::full_null("map", df.height()) + ChunkedArray::full_null(PlSmallStr::from_static("map"), df.height()) } else { let iter = apply_iter::(df, py, lambda, init_null_count, skip); - iterator_to_string(iter, init_null_count, first_value, "map", df.height()) + iterator_to_string( + iter, + init_null_count, + first_value, + PlSmallStr::from_static("map"), + df.height(), + ) } } @@ -220,7 +237,10 @@ pub fn apply_lambda_with_list_out_type<'a>( ) -> PyResult { let skip = usize::from(first_value.is_some()); if init_null_count == df.height() { - Ok(ChunkedArray::full_null("map", df.height())) + Ok(ChunkedArray::full_null( + PlSmallStr::from_static("map"), + df.height(), + )) } else { let mut iters = get_iters_skip(df, init_null_count + skip); let iter = ((init_null_count + skip)..df.height()).map(|_| { @@ -240,7 +260,14 @@ pub fn apply_lambda_with_list_out_type<'a>( Err(e) => panic!("python function failed {e}"), } }); - iterator_to_list(dt, iter, init_null_count, first_value, "map", df.height()) + iterator_to_list( + dt, + iter, + init_null_count, + first_value, + PlSmallStr::from_static("map"), + df.height(), + ) } } diff --git a/py-polars/src/map/lazy.rs b/crates/polars-python/src/map/lazy.rs similarity index 98% rename from py-polars/src/map/lazy.rs rename to crates/polars-python/src/map/lazy.rs index 759f1d25f443..f7edcbe3facb 100644 --- a/py-polars/src/map/lazy.rs +++ b/crates/polars-python/src/map/lazy.rs @@ -194,7 +194,7 @@ pub fn map_mul( let output_map = GetOutput::map_field(move |fld| { Ok(match output_type { - Some(ref dt) => Field::new(fld.name(), dt.0.clone()), + Some(ref dt) => Field::new(fld.name().clone(), dt.0.clone()), None => fld.clone(), }) }); diff --git a/py-polars/src/map/mod.rs b/crates/polars-python/src/map/mod.rs similarity index 92% rename from py-polars/src/map/mod.rs rename to crates/polars-python/src/map/mod.rs index db21681a04a8..8f6ed1518fe8 100644 --- a/py-polars/src/map/mod.rs +++ b/crates/polars-python/src/map/mod.rs @@ -9,10 +9,10 @@ use polars::prelude::*; use polars_core::export::rayon::prelude::*; use 
polars_core::utils::CustomIterTools; use polars_core::POOL; +use polars_utils::pl_str::PlSmallStr; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; use pyo3::types::PyDict; -use smartstring::alias::String as SmartString; use crate::error::PyPolarsErr; use crate::prelude::ObjectValue; @@ -35,7 +35,7 @@ fn iterator_to_struct<'a>( it: impl Iterator>>, init_null_count: usize, first_value: AnyValue<'a>, - name: &str, + name: PlSmallStr, capacity: usize, ) -> PyResult { let (vals, flds) = match &first_value { @@ -54,11 +54,11 @@ fn iterator_to_struct<'a>( // [ a values ] // [ b values ] // ] - let mut struct_fields: BTreeMap> = BTreeMap::new(); + let mut struct_fields: BTreeMap> = BTreeMap::new(); // As a BTreeMap sorts its keys, we also need to track the original // order of the field names. - let mut field_names_ordered: Vec = Vec::with_capacity(flds.len()); + let mut field_names_ordered: Vec = Vec::with_capacity(flds.len()); // Use the first value and the known null count to initialize the buffers // if we find a new key later on, we make a new entry in the BTree. @@ -96,7 +96,7 @@ fn iterator_to_struct<'a>( let mut buf = Vec::with_capacity(capacity); buf.extend((0..init_null_count + current_len).map(|_| AnyValue::Null)); buf.push(item.0); - let key: SmartString = (&*key).into(); + let key: PlSmallStr = (&*key).into(); field_names_ordered.push(key.clone()); struct_fields.insert(key, buf); }; @@ -118,7 +118,7 @@ fn iterator_to_struct<'a>( let fields = POOL.install(|| { field_names_ordered .par_iter() - .map(|name| Series::new(name, struct_fields.get(name).unwrap())) + .map(|name| Series::new(name.clone(), struct_fields.get(name).unwrap())) .collect::>() }); @@ -132,7 +132,7 @@ fn iterator_to_primitive( it: impl Iterator>, init_null_count: usize, first_value: Option, - name: &str, + name: PlSmallStr, capacity: usize, ) -> ChunkedArray where @@ -164,7 +164,7 @@ fn iterator_to_bool( it: impl Iterator>, init_null_count: usize, first_value: Option, - name: &str, + name: PlSmallStr, capacity: usize, ) -> ChunkedArray { // SAFETY: we know the iterators len. @@ -194,7 +194,7 @@ fn iterator_to_object( it: impl Iterator>, init_null_count: usize, first_value: Option, - name: &str, + name: PlSmallStr, capacity: usize, ) -> ObjectChunked { // SAFETY: we know the iterators len. @@ -223,7 +223,7 @@ fn iterator_to_string>( it: impl Iterator>, init_null_count: usize, first_value: Option, - name: &str, + name: PlSmallStr, capacity: usize, ) -> StringChunked { // SAFETY: we know the iterators len. @@ -252,7 +252,7 @@ fn iterator_to_list( it: impl Iterator>, init_null_count: usize, first_value: Option<&Series>, - name: &str, + name: PlSmallStr, capacity: usize, ) -> PyResult { let mut builder = @@ -260,16 +260,18 @@ fn iterator_to_list( for _ in 0..init_null_count { builder.append_null() } - builder - .append_opt_series(first_value) - .map_err(PyPolarsErr::from)?; + if first_value.is_some() { + builder + .append_opt_series(first_value) + .map_err(PyPolarsErr::from)?; + } for opt_val in it { match opt_val { None => builder.append_null(), Some(s) => { if s.len() == 0 && s.dtype() != dt { builder - .append_series(&Series::full_null("", 0, dt)) + .append_series(&Series::full_null(PlSmallStr::EMPTY, 0, dt)) .unwrap() } else { builder.append_series(&s).map_err(PyPolarsErr::from)? 
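The map/mod.rs hunks above are a compact summary of the naming migration that runs through the whole move to crates/polars-python: helper signatures drop `&str`/`SmartString` names in favour of `PlSmallStr`, call sites build names with `PlSmallStr::from_static`, `PlSmallStr::from_string`, or `.into()`, and the bare `""` placeholder becomes `PlSmallStr::EMPTY`. Below is a minimal sketch of that calling convention, not code from the PR; it assumes `PlSmallStr` is re-exported via `polars::prelude` and that `Series::new`/`Field::new` take the name as a `PlSmallStr`, as the call sites in the diff suggest, and the column names are purely illustrative.

    use polars::prelude::*;

    fn main() {
        // Static literals avoid an allocation at the call site.
        let static_name = PlSmallStr::from_static("map");
        // Owned Strings (e.g. values coming in from Python) go through from_string.
        let owned_name = PlSmallStr::from_string(String::from("value"));
        // A plain &str still converts via Into, which is how most call sites in
        // the diff adapt: `name.into()`, `(&*name).into()`, `|x| x.into()`.
        let converted: PlSmallStr = "category".into();

        // Series and Field constructors take the name as a PlSmallStr.
        let s = Series::new(static_name, &[1i32, 2, 3]);
        let f = Field::new(owned_name, DataType::String);

        // PlSmallStr::EMPTY replaces the bare "" previously used for anonymous FFI fields.
        assert_eq!(PlSmallStr::EMPTY.as_str(), "");
        assert_eq!(converted.as_str(), "category");
        println!("{s:?}");
        println!("{f:?}");
    }

Using one dedicated small-string type for all column and field names keeps them cheap to clone and consistent across the Rust and Python layers, which is presumably why the diff threads `PlSmallStr` through every `iterator_to_*` helper rather than keeping `&str` parameters.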
diff --git a/py-polars/src/map/series.rs b/crates/polars-python/src/map/series.rs similarity index 89% rename from py-polars/src/map/series.rs rename to crates/polars-python/src/map/series.rs index 9ec530002429..7a0d7f0887e9 100644 --- a/py-polars/src/map/series.rs +++ b/crates/polars-python/src/map/series.rs @@ -39,23 +39,29 @@ fn infer_and_finish<'a, A: ApplyLambda<'a>>( let series = py_pyseries.extract::().unwrap().series; let dt = series.dtype(); applyer - .apply_lambda_with_list_out_type(py, lambda.to_object(py), null_count, &series, dt) + .apply_lambda_with_list_out_type( + py, + lambda.to_object(py), + null_count, + Some(&series), + dt, + ) .map(|ca| ca.into_series().into()) } else if out.is_instance_of::() || out.is_instance_of::() { let series = SERIES.call1(py, (out,))?; let py_pyseries = series.getattr(py, "_s").unwrap(); let series = py_pyseries.extract::(py).unwrap().series; - // Empty dtype is incorrect, use AnyValues. - if series.is_empty() { + let dt = series.dtype(); + + // Null dtype may be incorrect, fall back to AnyValues logic. + if dt.is_nested_null() { let av = out.extract::>()?; return applyer .apply_extract_any_values(py, lambda, null_count, av.0) .map(|s| s.into()); } - let dt = series.dtype(); - // make a new python function that is: // def new_lambda(lambda: Callable): // pl.Series(lambda(value)) @@ -63,13 +69,14 @@ fn infer_and_finish<'a, A: ApplyLambda<'a>>( let new_lambda = PyCFunction::new_closure_bound(py, None, None, move |args, _kwargs| { Python::with_gil(|py| { let out = lambda_owned.call1(py, args)?; + // check if Series, if not, call series constructor on it SERIES.call1(py, (out,)) }) })? .to_object(py); let result = applyer - .apply_lambda_with_list_out_type(py, new_lambda, null_count, &series, dt) + .apply_lambda_with_list_out_type(py, new_lambda, null_count, Some(&series), dt) .map(|ca| ca.into_series().into()); match result { Ok(out) => Ok(out), @@ -172,7 +179,7 @@ pub trait ApplyLambda<'a> { py: Python, lambda: PyObject, init_null_count: usize, - first_value: &Series, + first_value: Option<&Series>, dt: &DataType, ) -> PyResult; @@ -248,7 +255,7 @@ impl<'a> ApplyLambda<'a> for BooleanChunked { null_count += 1 } } - Ok(Self::full_null(self.name(), self.len()) + Ok(Self::full_null(self.name().clone(), self.len()) .into_series() .into()) } @@ -266,13 +273,25 @@ impl<'a> ApplyLambda<'a> for BooleanChunked { .into_no_null_iter() .skip(init_null_count + skip) .map(|val| call_lambda(py, lambda, val).ok()); - iterator_to_struct(it, init_null_count, first_value, self.name(), self.len()) + iterator_to_struct( + it, + init_null_count, + first_value, + self.name().clone(), + self.len(), + ) } else { let it = self .into_iter() .skip(init_null_count + skip) .map(|opt_val| opt_val.and_then(|val| call_lambda(py, lambda, val).ok())); - iterator_to_struct(it, init_null_count, first_value, self.name(), self.len()) + iterator_to_struct( + it, + init_null_count, + first_value, + self.name().clone(), + self.len(), + ) } } @@ -289,7 +308,7 @@ impl<'a> ApplyLambda<'a> for BooleanChunked { { let skip = usize::from(first_value.is_some()); if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -299,7 +318,7 @@ impl<'a> ApplyLambda<'a> for BooleanChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -313,7 +332,7 @@ impl<'a> ApplyLambda<'a> for 
BooleanChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -328,7 +347,7 @@ impl<'a> ApplyLambda<'a> for BooleanChunked { ) -> PyResult { let skip = usize::from(first_value.is_some()); if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -338,7 +357,7 @@ impl<'a> ApplyLambda<'a> for BooleanChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -352,7 +371,7 @@ impl<'a> ApplyLambda<'a> for BooleanChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -367,7 +386,7 @@ impl<'a> ApplyLambda<'a> for BooleanChunked { ) -> PyResult { let skip = usize::from(first_value.is_some()); if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -380,7 +399,7 @@ impl<'a> ApplyLambda<'a> for BooleanChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -394,7 +413,7 @@ impl<'a> ApplyLambda<'a> for BooleanChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -405,13 +424,13 @@ impl<'a> ApplyLambda<'a> for BooleanChunked { py: Python, lambda: PyObject, init_null_count: usize, - first_value: &Series, + first_value: Option<&Series>, dt: &DataType, ) -> PyResult { - let skip = 1; + let skip = usize::from(first_value.is_some()); let lambda = lambda.bind(py); if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -422,8 +441,8 @@ impl<'a> ApplyLambda<'a> for BooleanChunked { dt, it, init_null_count, - Some(first_value), - self.name(), + first_value, + self.name().clone(), self.len(), ) } else { @@ -437,8 +456,8 @@ impl<'a> ApplyLambda<'a> for BooleanChunked { dt, it, init_null_count, - Some(first_value), - self.name(), + first_value, + self.name().clone(), self.len(), ) } @@ -475,7 +494,7 @@ impl<'a> ApplyLambda<'a> for BooleanChunked { }); avs.extend(iter); } - Ok(Series::new(self.name(), &avs)) + Ok(Series::new(self.name().clone(), &avs)) } #[cfg(feature = "object")] @@ -488,7 +507,7 @@ impl<'a> ApplyLambda<'a> for BooleanChunked { ) -> PyResult> { let skip = usize::from(first_value.is_some()); if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -499,7 +518,7 @@ impl<'a> ApplyLambda<'a> for BooleanChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -513,7 +532,7 @@ impl<'a> ApplyLambda<'a> for BooleanChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -541,7 +560,7 @@ where null_count += 1 } } - Ok(Self::full_null(self.name(), self.len()) + Ok(Self::full_null(self.name().clone(), self.len()) .into_series() .into()) } @@ -559,13 +578,25 @@ where .into_no_null_iter() .skip(init_null_count + skip) .map(|val| call_lambda(py, lambda, val).ok()); - iterator_to_struct(it, init_null_count, first_value, 
self.name(), self.len()) + iterator_to_struct( + it, + init_null_count, + first_value, + self.name().clone(), + self.len(), + ) } else { let it = self .into_iter() .skip(init_null_count + skip) .map(|opt_val| opt_val.and_then(|val| call_lambda(py, lambda, val).ok())); - iterator_to_struct(it, init_null_count, first_value, self.name(), self.len()) + iterator_to_struct( + it, + init_null_count, + first_value, + self.name().clone(), + self.len(), + ) } } @@ -582,7 +613,7 @@ where { let skip = usize::from(first_value.is_some()); if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -592,7 +623,7 @@ where it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -606,7 +637,7 @@ where it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -621,7 +652,7 @@ where ) -> PyResult { let skip = usize::from(first_value.is_some()); if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -631,7 +662,7 @@ where it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -645,7 +676,7 @@ where it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -660,7 +691,7 @@ where ) -> PyResult { let skip = usize::from(first_value.is_some()); if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -671,7 +702,7 @@ where it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -685,7 +716,7 @@ where it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -696,13 +727,13 @@ where py: Python, lambda: PyObject, init_null_count: usize, - first_value: &Series, + first_value: Option<&Series>, dt: &DataType, ) -> PyResult { - let skip = 1; + let skip = usize::from(first_value.is_some()); let lambda = lambda.bind(py); if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -713,8 +744,8 @@ where dt, it, init_null_count, - Some(first_value), - self.name(), + first_value, + self.name().clone(), self.len(), ) } else { @@ -728,8 +759,8 @@ where dt, it, init_null_count, - Some(first_value), - self.name(), + first_value, + self.name().clone(), self.len(), ) } @@ -766,7 +797,7 @@ where }); avs.extend(iter); } - Ok(Series::new(self.name(), &avs)) + Ok(Series::new(self.name().clone(), &avs)) } #[cfg(feature = "object")] @@ -779,7 +810,7 @@ where ) -> PyResult> { let skip = usize::from(first_value.is_some()); if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -790,7 +821,7 @@ where it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -804,7 +835,7 @@ where it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ 
-827,7 +858,7 @@ impl<'a> ApplyLambda<'a> for StringChunked { null_count += 1 } } - Ok(Self::full_null(self.name(), self.len()) + Ok(Self::full_null(self.name().clone(), self.len()) .into_series() .into()) } @@ -845,13 +876,25 @@ impl<'a> ApplyLambda<'a> for StringChunked { .into_no_null_iter() .skip(init_null_count + skip) .map(|val| call_lambda(py, lambda, val).ok()); - iterator_to_struct(it, init_null_count, first_value, self.name(), self.len()) + iterator_to_struct( + it, + init_null_count, + first_value, + self.name().clone(), + self.len(), + ) } else { let it = self .into_iter() .skip(init_null_count + skip) .map(|opt_val| opt_val.and_then(|val| call_lambda(py, lambda, val).ok())); - iterator_to_struct(it, init_null_count, first_value, self.name(), self.len()) + iterator_to_struct( + it, + init_null_count, + first_value, + self.name().clone(), + self.len(), + ) } } @@ -868,7 +911,7 @@ impl<'a> ApplyLambda<'a> for StringChunked { { let skip = usize::from(first_value.is_some()); if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -878,7 +921,7 @@ impl<'a> ApplyLambda<'a> for StringChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -892,7 +935,7 @@ impl<'a> ApplyLambda<'a> for StringChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -907,7 +950,7 @@ impl<'a> ApplyLambda<'a> for StringChunked { ) -> PyResult { let skip = usize::from(first_value.is_some()); if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -917,7 +960,7 @@ impl<'a> ApplyLambda<'a> for StringChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -931,7 +974,7 @@ impl<'a> ApplyLambda<'a> for StringChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -946,7 +989,7 @@ impl<'a> ApplyLambda<'a> for StringChunked { ) -> PyResult { let skip = usize::from(first_value.is_some()); if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -957,7 +1000,7 @@ impl<'a> ApplyLambda<'a> for StringChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -971,7 +1014,7 @@ impl<'a> ApplyLambda<'a> for StringChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -981,13 +1024,13 @@ impl<'a> ApplyLambda<'a> for StringChunked { py: Python, lambda: PyObject, init_null_count: usize, - first_value: &Series, + first_value: Option<&Series>, dt: &DataType, ) -> PyResult { - let skip = 1; + let skip = usize::from(first_value.is_some()); let lambda = lambda.bind(py); if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -998,8 +1041,8 @@ impl<'a> ApplyLambda<'a> for StringChunked { dt, it, init_null_count, - Some(first_value), - self.name(), + first_value, + self.name().clone(), self.len(), 
) } else { @@ -1013,8 +1056,8 @@ impl<'a> ApplyLambda<'a> for StringChunked { dt, it, init_null_count, - Some(first_value), - self.name(), + first_value, + self.name().clone(), self.len(), ) } @@ -1051,7 +1094,7 @@ impl<'a> ApplyLambda<'a> for StringChunked { }); avs.extend(iter); } - Ok(Series::new(self.name(), &avs)) + Ok(Series::new(self.name().clone(), &avs)) } #[cfg(feature = "object")] @@ -1064,7 +1107,7 @@ impl<'a> ApplyLambda<'a> for StringChunked { ) -> PyResult> { let skip = usize::from(first_value.is_some()); if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -1075,7 +1118,7 @@ impl<'a> ApplyLambda<'a> for StringChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -1089,7 +1132,7 @@ impl<'a> ApplyLambda<'a> for StringChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -1150,7 +1193,7 @@ impl<'a> ApplyLambda<'a> for ListChunked { null_count += 1 } } - Ok(Self::full_null(self.name(), self.len()) + Ok(Self::full_null(self.name().clone(), self.len()) .into_series() .into()) } @@ -1180,7 +1223,13 @@ impl<'a> ApplyLambda<'a> for ListChunked { .unwrap(); call_lambda(py, lambda, python_series_wrapper).ok() }); - iterator_to_struct(it, init_null_count, first_value, self.name(), self.len()) + iterator_to_struct( + it, + init_null_count, + first_value, + self.name().clone(), + self.len(), + ) } else { let it = self .into_iter() @@ -1198,7 +1247,13 @@ impl<'a> ApplyLambda<'a> for ListChunked { call_lambda(py, lambda, python_series_wrapper).ok() }) }); - iterator_to_struct(it, init_null_count, first_value, self.name(), self.len()) + iterator_to_struct( + it, + init_null_count, + first_value, + self.name().clone(), + self.len(), + ) } } @@ -1216,7 +1271,7 @@ impl<'a> ApplyLambda<'a> for ListChunked { let skip = usize::from(first_value.is_some()); let pypolars = PyModule::import_bound(py, "polars")?; if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -1236,7 +1291,7 @@ impl<'a> ApplyLambda<'a> for ListChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -1260,7 +1315,7 @@ impl<'a> ApplyLambda<'a> for ListChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -1276,7 +1331,7 @@ impl<'a> ApplyLambda<'a> for ListChunked { let skip = usize::from(first_value.is_some()); let pypolars = PyModule::import_bound(py, "polars")?; if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -1296,7 +1351,7 @@ impl<'a> ApplyLambda<'a> for ListChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -1320,7 +1375,7 @@ impl<'a> ApplyLambda<'a> for ListChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -1338,7 +1393,7 @@ impl<'a> ApplyLambda<'a> for ListChunked { let pypolars = PyModule::import_bound(py, "polars")?; if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + 
Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -1359,7 +1414,7 @@ impl<'a> ApplyLambda<'a> for ListChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -1383,7 +1438,7 @@ impl<'a> ApplyLambda<'a> for ListChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -1393,14 +1448,14 @@ impl<'a> ApplyLambda<'a> for ListChunked { py: Python, lambda: PyObject, init_null_count: usize, - first_value: &Series, + first_value: Option<&Series>, dt: &DataType, ) -> PyResult { - let skip = 1; + let skip = usize::from(first_value.is_some()); let pypolars = PyModule::import_bound(py, "polars")?; let lambda = lambda.bind(py); if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -1411,8 +1466,8 @@ impl<'a> ApplyLambda<'a> for ListChunked { dt, it, init_null_count, - Some(first_value), - self.name(), + first_value, + self.name().clone(), self.len(), ) } else { @@ -1424,8 +1479,8 @@ impl<'a> ApplyLambda<'a> for ListChunked { dt, it, init_null_count, - Some(first_value), - self.name(), + first_value, + self.name().clone(), self.len(), ) } @@ -1473,7 +1528,7 @@ impl<'a> ApplyLambda<'a> for ListChunked { .map(call_with_value); avs.extend(iter); } - Ok(Series::new(self.name(), &avs)) + Ok(Series::new(self.name().clone(), &avs)) } #[cfg(feature = "object")] @@ -1487,7 +1542,7 @@ impl<'a> ApplyLambda<'a> for ListChunked { let skip = usize::from(first_value.is_some()); let pypolars = PyModule::import_bound(py, "polars")?; if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -1508,7 +1563,7 @@ impl<'a> ApplyLambda<'a> for ListChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -1532,7 +1587,7 @@ impl<'a> ApplyLambda<'a> for ListChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -1565,7 +1620,7 @@ impl<'a> ApplyLambda<'a> for ArrayChunked { null_count += 1 } } - Ok(Self::full_null(self.name(), self.len()) + Ok(Self::full_null(self.name().clone(), self.len()) .into_series() .into()) } @@ -1595,7 +1650,13 @@ impl<'a> ApplyLambda<'a> for ArrayChunked { .unwrap(); call_lambda(py, lambda, python_series_wrapper).ok() }); - iterator_to_struct(it, init_null_count, first_value, self.name(), self.len()) + iterator_to_struct( + it, + init_null_count, + first_value, + self.name().clone(), + self.len(), + ) } else { let it = self .into_iter() @@ -1613,7 +1674,13 @@ impl<'a> ApplyLambda<'a> for ArrayChunked { call_lambda(py, lambda, python_series_wrapper).ok() }) }); - iterator_to_struct(it, init_null_count, first_value, self.name(), self.len()) + iterator_to_struct( + it, + init_null_count, + first_value, + self.name().clone(), + self.len(), + ) } } @@ -1631,7 +1698,7 @@ impl<'a> ApplyLambda<'a> for ArrayChunked { let skip = usize::from(first_value.is_some()); let pypolars = PyModule::import_bound(py, "polars")?; if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self 
.into_no_null_iter() @@ -1651,7 +1718,7 @@ impl<'a> ApplyLambda<'a> for ArrayChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -1675,7 +1742,7 @@ impl<'a> ApplyLambda<'a> for ArrayChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -1691,7 +1758,7 @@ impl<'a> ApplyLambda<'a> for ArrayChunked { let skip = usize::from(first_value.is_some()); let pypolars = PyModule::import_bound(py, "polars")?; if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -1711,7 +1778,7 @@ impl<'a> ApplyLambda<'a> for ArrayChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -1735,7 +1802,7 @@ impl<'a> ApplyLambda<'a> for ArrayChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -1753,7 +1820,7 @@ impl<'a> ApplyLambda<'a> for ArrayChunked { let pypolars = PyModule::import_bound(py, "polars")?; if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -1774,7 +1841,7 @@ impl<'a> ApplyLambda<'a> for ArrayChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -1798,7 +1865,7 @@ impl<'a> ApplyLambda<'a> for ArrayChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -1808,14 +1875,14 @@ impl<'a> ApplyLambda<'a> for ArrayChunked { py: Python, lambda: PyObject, init_null_count: usize, - first_value: &Series, + first_value: Option<&Series>, dt: &DataType, ) -> PyResult { - let skip = 1; + let skip = usize::from(first_value.is_some()); let pypolars = PyModule::import_bound(py, "polars")?; let lambda = lambda.bind(py); if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -1826,8 +1893,8 @@ impl<'a> ApplyLambda<'a> for ArrayChunked { dt, it, init_null_count, - Some(first_value), - self.name(), + first_value, + self.name().clone(), self.len(), ) } else { @@ -1839,8 +1906,8 @@ impl<'a> ApplyLambda<'a> for ArrayChunked { dt, it, init_null_count, - Some(first_value), - self.name(), + first_value, + self.name().clone(), self.len(), ) } @@ -1888,7 +1955,7 @@ impl<'a> ApplyLambda<'a> for ArrayChunked { .map(call_with_value); avs.extend(iter); } - Ok(Series::new(self.name(), &avs)) + Ok(Series::new(self.name().clone(), &avs)) } #[cfg(feature = "object")] @@ -1902,7 +1969,7 @@ impl<'a> ApplyLambda<'a> for ArrayChunked { let skip = usize::from(first_value.is_some()); let pypolars = PyModule::import_bound(py, "polars")?; if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -1923,7 +1990,7 @@ impl<'a> ApplyLambda<'a> for ArrayChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -1947,7 +2014,7 @@ impl<'a> ApplyLambda<'a> for ArrayChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), 
self.len(), )) } @@ -1971,7 +2038,7 @@ impl<'a> ApplyLambda<'a> for ObjectChunked { null_count += 1 } } - Ok(Self::full_null(self.name(), self.len()) + Ok(Self::full_null(self.name().clone(), self.len()) .into_series() .into()) } @@ -1991,7 +2058,13 @@ impl<'a> ApplyLambda<'a> for ObjectChunked { let out = lambda.call1((object_value.map(|v| &v.inner),)).unwrap(); Some(out) }); - iterator_to_struct(it, init_null_count, first_value, self.name(), self.len()) + iterator_to_struct( + it, + init_null_count, + first_value, + self.name().clone(), + self.len(), + ) } fn apply_lambda_with_primitive_out_type( @@ -2007,7 +2080,7 @@ impl<'a> ApplyLambda<'a> for ObjectChunked { { let skip = usize::from(first_value.is_some()); if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -2017,7 +2090,7 @@ impl<'a> ApplyLambda<'a> for ObjectChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -2031,7 +2104,7 @@ impl<'a> ApplyLambda<'a> for ObjectChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -2046,7 +2119,7 @@ impl<'a> ApplyLambda<'a> for ObjectChunked { ) -> PyResult { let skip = usize::from(first_value.is_some()); if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -2056,7 +2129,7 @@ impl<'a> ApplyLambda<'a> for ObjectChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -2070,7 +2143,7 @@ impl<'a> ApplyLambda<'a> for ObjectChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -2085,7 +2158,7 @@ impl<'a> ApplyLambda<'a> for ObjectChunked { ) -> PyResult { let skip = usize::from(first_value.is_some()); if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -2096,7 +2169,7 @@ impl<'a> ApplyLambda<'a> for ObjectChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -2110,7 +2183,7 @@ impl<'a> ApplyLambda<'a> for ObjectChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -2121,13 +2194,13 @@ impl<'a> ApplyLambda<'a> for ObjectChunked { py: Python, lambda: PyObject, init_null_count: usize, - first_value: &Series, + first_value: Option<&Series>, dt: &DataType, ) -> PyResult { - let skip = 1; + let skip = usize::from(first_value.is_some()); let lambda = lambda.bind(py); if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -2138,8 +2211,8 @@ impl<'a> ApplyLambda<'a> for ObjectChunked { dt, it, init_null_count, - Some(first_value), - self.name(), + first_value, + self.name().clone(), self.len(), ) } else { @@ -2153,8 +2226,8 @@ impl<'a> ApplyLambda<'a> for ObjectChunked { dt, it, init_null_count, - Some(first_value), - self.name(), + first_value, + self.name().clone(), self.len(), ) } @@ -2191,7 +2264,7 @@ impl<'a> ApplyLambda<'a> for ObjectChunked { 
}); avs.extend(iter); } - Ok(Series::new(self.name(), &avs)) + Ok(Series::new(self.name().clone(), &avs)) } #[cfg(feature = "object")] @@ -2204,7 +2277,7 @@ impl<'a> ApplyLambda<'a> for ObjectChunked { ) -> PyResult> { let skip = usize::from(first_value.is_some()); if init_null_count == self.len() { - Ok(ChunkedArray::full_null(self.name(), self.len())) + Ok(ChunkedArray::full_null(self.name().clone(), self.len())) } else if !self.has_nulls() { let it = self .into_no_null_iter() @@ -2215,7 +2288,7 @@ impl<'a> ApplyLambda<'a> for ObjectChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } else { @@ -2229,7 +2302,7 @@ impl<'a> ApplyLambda<'a> for ObjectChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -2269,7 +2342,13 @@ impl<'a> ApplyLambda<'a> for StructChunked { let out = lambda.call1((Wrap(val),)).unwrap(); Some(out) }); - iterator_to_struct(it, init_null_count, first_value, self.name(), self.len()) + iterator_to_struct( + it, + init_null_count, + first_value, + self.name().clone(), + self.len(), + ) } fn apply_lambda_with_primitive_out_type( @@ -2292,7 +2371,7 @@ impl<'a> ApplyLambda<'a> for StructChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -2313,7 +2392,7 @@ impl<'a> ApplyLambda<'a> for StructChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -2334,7 +2413,7 @@ impl<'a> ApplyLambda<'a> for StructChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } @@ -2343,10 +2422,10 @@ impl<'a> ApplyLambda<'a> for StructChunked { py: Python, lambda: PyObject, init_null_count: usize, - first_value: &Series, + first_value: Option<&Series>, dt: &DataType, ) -> PyResult { - let skip = 1; + let skip = usize::from(first_value.is_some()); let lambda = lambda.bind(py); let it = iter_struct(self) .skip(init_null_count + skip) @@ -2355,8 +2434,8 @@ impl<'a> ApplyLambda<'a> for StructChunked { dt, it, init_null_count, - Some(first_value), - self.name(), + first_value, + self.name().clone(), self.len(), ) } @@ -2379,7 +2458,7 @@ impl<'a> ApplyLambda<'a> for StructChunked { }); avs.extend(iter); - Ok(Series::new(self.name(), &avs)) + Ok(Series::new(self.name().clone(), &avs)) } #[cfg(feature = "object")] @@ -2399,7 +2478,7 @@ impl<'a> ApplyLambda<'a> for StructChunked { it, init_null_count, first_value, - self.name(), + self.name().clone(), self.len(), )) } diff --git a/py-polars/src/object.rs b/crates/polars-python/src/object.rs similarity index 100% rename from py-polars/src/object.rs rename to crates/polars-python/src/object.rs diff --git a/py-polars/src/on_startup.rs b/crates/polars-python/src/on_startup.rs similarity index 97% rename from py-polars/src/on_startup.rs rename to crates/polars-python/src/on_startup.rs index bca2da58fa2c..3f08f71740b5 100644 --- a/py-polars/src/on_startup.rs +++ b/crates/polars-python/src/on_startup.rs @@ -65,8 +65,7 @@ fn warning_function(msg: &str, warning: PolarsWarning) { }); } -#[pyfunction] -pub fn __register_startup_deps() { +pub fn register_startup_deps() { set_polars_allow_extension(true); if !registry::is_object_builder_registered() { // Stack frames can get really large in debug mode. 
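// [editor's note — illustrative sketch, not part of the upstream diff]
// Two conventions recur throughout the ApplyLambda impls above: `name()` now
// returns `&PlSmallStr`, so call sites clone it (a cheap small-string copy),
// and the list-output path takes `first_value: Option<&Series>`, skipping one
// extra leading element only when a first value was already materialised.
// Hedged sketch, assuming the `polars` prelude re-exports the types used:
use polars::prelude::*;

fn full_null_like(s: &Series) -> Series {
    // `s.name()` is `&PlSmallStr`; cloning it yields an owned name cheaply.
    Series::full_null(s.name().clone(), s.len(), s.dtype())
}

fn elements_to_skip(init_null_count: usize, first_value: Option<&Series>) -> usize {
    // mirrors `let skip = usize::from(first_value.is_some());` in the hunks above
    init_null_count + usize::from(first_value.is_some())
}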
@@ -77,7 +76,7 @@ pub fn __register_startup_deps() { } // register object type builder - let object_builder = Box::new(|name: &str, capacity: usize| { + let object_builder = Box::new(|name: PlSmallStr, capacity: usize| { Box::new(ObjectChunkedBuilder::::new(name, capacity)) as Box }); diff --git a/py-polars/src/prelude.rs b/crates/polars-python/src/prelude.rs similarity index 100% rename from py-polars/src/prelude.rs rename to crates/polars-python/src/prelude.rs diff --git a/py-polars/src/py_modules.rs b/crates/polars-python/src/py_modules.rs similarity index 100% rename from py-polars/src/py_modules.rs rename to crates/polars-python/src/py_modules.rs diff --git a/py-polars/src/series/aggregation.rs b/crates/polars-python/src/series/aggregation.rs similarity index 98% rename from py-polars/src/series/aggregation.rs rename to crates/polars-python/src/series/aggregation.rs index d8a5364afe3b..ac324794564c 100644 --- a/py-polars/src/series/aggregation.rs +++ b/crates/polars-python/src/series/aggregation.rs @@ -1,9 +1,10 @@ +use polars::prelude::*; use pyo3::prelude::*; use DataType::*; +use super::PySeries; +use crate::conversion::Wrap; use crate::error::PyPolarsErr; -use crate::prelude::*; -use crate::PySeries; #[pymethods] impl PySeries { diff --git a/py-polars/src/series/arithmetic.rs b/crates/polars-python/src/series/arithmetic.rs similarity index 99% rename from py-polars/src/series/arithmetic.rs rename to crates/polars-python/src/series/arithmetic.rs index e99b4d8634c1..c5483aced1e7 100644 --- a/py-polars/src/series/arithmetic.rs +++ b/crates/polars-python/src/series/arithmetic.rs @@ -1,8 +1,8 @@ +use polars::prelude::*; use pyo3::prelude::*; +use super::PySeries; use crate::error::PyPolarsErr; -use crate::prelude::*; -use crate::PySeries; #[pymethods] impl PySeries { diff --git a/py-polars/src/series/buffers.rs b/crates/polars-python/src/series/buffers.rs similarity index 94% rename from py-polars/src/series/buffers.rs rename to crates/polars-python/src/series/buffers.rs index 02595018df0d..49610fd3cf42 100644 --- a/py-polars/src/series/buffers.rs +++ b/crates/polars-python/src/series/buffers.rs @@ -18,9 +18,15 @@ use polars::export::arrow::bitmap::Bitmap; use polars::export::arrow::buffer::Buffer; use polars::export::arrow::offset::OffsetsBuffer; use polars::export::arrow::types::NativeType; +use polars::prelude::*; +use polars_core::{with_match_physical_numeric_polars_type, with_match_physical_numeric_type}; use pyo3::exceptions::PyTypeError; +use pyo3::prelude::*; -use super::*; +use super::{PySeries, ToSeries}; +use crate::conversion::Wrap; +use crate::error::PyPolarsErr; +use crate::raise_err; struct BufferInfo { pointer: usize, @@ -103,7 +109,7 @@ fn get_buffers_from_primitive( .iter() .map(|arr| arr.with_validity(None)) .collect::>(); - let values = Series::try_from((s.name(), chunks)) + let values = Series::try_from((s.name().clone(), chunks)) .map_err(PyPolarsErr::from)? .into(); @@ -145,7 +151,7 @@ fn get_string_bytes(arr: &Utf8Array) -> PyResult { let values_arr = PrimitiveArray::::try_new(ArrowDataType::UInt8, values_buffer.clone(), None) .map_err(PyPolarsErr::from)?; - let values = Series::from_arrow("", values_arr.to_boxed()) + let values = Series::from_arrow(PlSmallStr::EMPTY, values_arr.to_boxed()) .map_err(PyPolarsErr::from)? 
.into(); Ok(values) @@ -156,7 +162,7 @@ fn get_string_offsets(arr: &Utf8Array) -> PyResult { let offsets_arr = PrimitiveArray::::try_new(ArrowDataType::Int64, offsets_buffer.clone(), None) .map_err(PyPolarsErr::from)?; - let offsets = Series::from_arrow("", offsets_arr.to_boxed()) + let offsets = Series::from_arrow(PlSmallStr::EMPTY, offsets_arr.to_boxed()) .map_err(PyPolarsErr::from)? .into(); Ok(offsets) @@ -197,7 +203,9 @@ impl PySeries { }, }; - let s = Series::from_arrow("", arr_boxed).unwrap().into(); + let s = Series::from_arrow(PlSmallStr::EMPTY, arr_boxed) + .unwrap() + .into(); Ok(s) } } @@ -349,13 +357,13 @@ fn from_buffers_num_impl( validity: Option, ) -> PyResult { let arr = PrimitiveArray::new(T::PRIMITIVE.into(), data, validity); - let s_result = Series::from_arrow("", arr.to_boxed()); + let s_result = Series::from_arrow(PlSmallStr::EMPTY, arr.to_boxed()); let s = s_result.map_err(PyPolarsErr::from)?; Ok(s) } fn from_buffers_bool_impl(data: Bitmap, validity: Option) -> PyResult { let arr = BooleanArray::new(ArrowDataType::Boolean, data, validity); - let s_result = Series::from_arrow("", arr.to_boxed()); + let s_result = Series::from_arrow(PlSmallStr::EMPTY, arr.to_boxed()); let s = s_result.map_err(PyPolarsErr::from)?; Ok(s) } @@ -370,7 +378,7 @@ fn from_buffers_string_impl( let arr = Utf8Array::new(ArrowDataType::LargeUtf8, offsets, data, validity); // This is not zero-copy - let s_result = Series::from_arrow("", arr.to_boxed()); + let s_result = Series::from_arrow(PlSmallStr::EMPTY, arr.to_boxed()); let s = s_result.map_err(PyPolarsErr::from)?; Ok(s) diff --git a/py-polars/src/series/c_interface.rs b/crates/polars-python/src/series/c_interface.rs similarity index 81% rename from py-polars/src/series/c_interface.rs rename to crates/polars-python/src/series/c_interface.rs index e3a7807765a7..e978efbf58b7 100644 --- a/py-polars/src/series/c_interface.rs +++ b/crates/polars-python/src/series/c_interface.rs @@ -1,7 +1,10 @@ use polars::export::arrow; +use polars::prelude::*; use pyo3::ffi::Py_uintptr_t; +use pyo3::prelude::*; -use super::*; +use super::PySeries; +use crate::error::PyPolarsErr; // Import arrow data directly without requiring pyarrow (used in pyo3-polars) #[pymethods] @@ -22,11 +25,11 @@ impl PySeries { let schema = &*schema_ptr; let field = arrow::ffi::import_field_from_c(schema).unwrap(); - arrow::ffi::import_array_from_c(array, field.data_type).unwrap() + arrow::ffi::import_array_from_c(array, field.dtype).unwrap() }) .collect::>(); - let s = Series::try_from((name, chunks)).map_err(PyPolarsErr::from)?; + let s = Series::try_new(name.into(), chunks).map_err(PyPolarsErr::from)?; Ok(s.into()) } @@ -51,7 +54,11 @@ unsafe fn export_chunk( let out_ptr = out_ptr as *mut arrow::ffi::ArrowArray; *out_ptr = c_array; - let field = ArrowField::new(s.name(), s.dtype().to_arrow(CompatLevel::newest()), true); + let field = ArrowField::new( + s.name().clone(), + s.dtype().to_arrow(CompatLevel::newest()), + true, + ); let c_schema = arrow::ffi::export_field_to_c(&field); let out_schema_ptr = out_schema_ptr as *mut arrow::ffi::ArrowSchema; diff --git a/py-polars/src/series/comparison.rs b/crates/polars-python/src/series/comparison.rs similarity index 97% rename from py-polars/src/series/comparison.rs rename to crates/polars-python/src/series/comparison.rs index 8ebd85021463..7064edb7698a 100644 --- a/py-polars/src/series/comparison.rs +++ b/crates/polars-python/src/series/comparison.rs @@ -227,7 +227,10 @@ macro_rules! 
impl_decimal { #[pymethods] impl PySeries { fn $name(&self, rhs: PyDecimal) -> PyResult { - let rhs = Series::new("decimal", &[AnyValue::Decimal(rhs.0, rhs.1)]); + let rhs = Series::new( + PlSmallStr::from_static("decimal"), + &[AnyValue::Decimal(rhs.0, rhs.1)], + ); let s = self.series.$method(&rhs).map_err(PyPolarsErr::from)?; Ok(s.into_series().into()) } diff --git a/py-polars/src/series/construction.rs b/crates/polars-python/src/series/construction.rs similarity index 88% rename from py-polars/src/series/construction.rs rename to crates/polars-python/src/series/construction.rs index c8361e7bb837..7482d3a96c13 100644 --- a/py-polars/src/series/construction.rs +++ b/crates/polars-python/src/series/construction.rs @@ -52,7 +52,9 @@ fn mmap_numpy_array( let vals = unsafe { array.as_slice().unwrap() }; let arr = unsafe { arrow::ffi::mmap::slice_and_owner(vals, array.to_object(py)) }; - Series::from_arrow(name, arr.to_boxed()).unwrap().into() + Series::from_arrow(name.into(), arr.to_boxed()) + .unwrap() + .into() } #[pymethods] @@ -61,7 +63,7 @@ impl PySeries { fn new_bool(py: Python, name: &str, array: &Bound>, _strict: bool) -> Self { let array = array.readonly(); let vals = array.as_slice().unwrap(); - py.allow_threads(|| Series::new(name, vals).into()) + py.allow_threads(|| Series::new(name.into(), vals).into()) } #[staticmethod] @@ -73,7 +75,7 @@ impl PySeries { .iter() .map(|&val| if f32::is_nan(val) { None } else { Some(val) }) .collect_trusted(); - ca.with_name(name).into_series().into() + ca.with_name(name.into()).into_series().into() } else { mmap_numpy_array(py, name, array) } @@ -88,7 +90,7 @@ impl PySeries { .iter() .map(|&val| if f64::is_nan(val) { None } else { Some(val) }) .collect_trusted(); - ca.with_name(name).into_series().into() + ca.with_name(name.into()).into_series().into() } else { mmap_numpy_array(py, name, array) } @@ -100,7 +102,7 @@ impl PySeries { #[staticmethod] fn new_opt_bool(name: &str, values: &Bound, _strict: bool) -> PyResult { let len = values.len()?; - let mut builder = BooleanChunkedBuilder::new(name, len); + let mut builder = BooleanChunkedBuilder::new(name.into(), len); for res in values.iter()? { let value = res?; @@ -125,7 +127,7 @@ where T::Native: FromPyObject<'a>, { let len = values.len()?; - let mut builder = PrimitiveChunkedBuilder::::new(name, len); + let mut builder = PrimitiveChunkedBuilder::::new(name.into(), len); for res in values.iter()? { let value = res?; @@ -175,7 +177,7 @@ impl PySeries { .map(|v| py_object_to_any_value(&(v?).as_borrowed(), strict)) .collect::>>(); let result = any_values_result.and_then(|avs| { - let s = Series::from_any_values(name, avs.as_slice(), strict).map_err(|e| { + let s = Series::from_any_values(name.into(), avs.as_slice(), strict).map_err(|e| { PyTypeError::new_err(format!( "{e}\n\nHint: Try setting `strict=False` to allow passing data with mixed types." )) @@ -213,19 +215,20 @@ impl PySeries { .iter()? .map(|v| py_object_to_any_value(&(v?).as_borrowed(), strict)) .collect::>>()?; - let s = Series::from_any_values_and_dtype(name, any_values.as_slice(), &dtype.0, strict) - .map_err(|e| { - PyTypeError::new_err(format!( - "{e}\n\nHint: Try setting `strict=False` to allow passing data with mixed types." - )) - })?; + let s = + Series::from_any_values_and_dtype(name.into(), any_values.as_slice(), &dtype.0, strict) + .map_err(|e| { + PyTypeError::new_err(format!( + "{e}\n\nHint: Try setting `strict=False` to allow passing data with mixed types." 
+ )) + })?; Ok(s.into()) } #[staticmethod] fn new_str(name: &str, values: &Bound, _strict: bool) -> PyResult { let len = values.len()?; - let mut builder = StringChunkedBuilder::new(name, len); + let mut builder = StringChunkedBuilder::new(name.into(), len); for res in values.iter()? { let value = res?; @@ -245,7 +248,7 @@ impl PySeries { #[staticmethod] fn new_binary(name: &str, values: &Bound, _strict: bool) -> PyResult { let len = values.len()?; - let mut builder = BinaryChunkedBuilder::new(name, len); + let mut builder = BinaryChunkedBuilder::new(name.into(), len); for res in values.iter()? { let value = res?; @@ -277,7 +280,7 @@ impl PySeries { )); } } - Ok(Series::new(name, series).into()) + Ok(Series::new(name.into(), series).into()) } #[staticmethod] @@ -303,7 +306,7 @@ impl PySeries { }); // Object builder must be registered. This is done on import. let ca = ObjectChunked::::new_from_vec_and_validity( - name, + name.into(), values, validity.into(), ); @@ -317,19 +320,19 @@ impl PySeries { #[staticmethod] fn new_null(name: &str, values: &Bound, _strict: bool) -> PyResult { let len = values.len()?; - Ok(Series::new_null(name, len).into()) + Ok(Series::new_null(name.into(), len).into()) } #[staticmethod] fn from_arrow(name: &str, array: &Bound) -> PyResult { let arr = array_to_rust(array)?; - match arr.data_type() { + match arr.dtype() { ArrowDataType::LargeList(_) => { let array = arr.as_any().downcast_ref::().unwrap(); let fast_explode = array.offsets().as_slice().windows(2).all(|w| w[0] != w[1]); - let mut out = ListChunked::with_chunk(name, array.clone()); + let mut out = ListChunked::with_chunk(name.into(), array.clone()); if fast_explode { out.set_fast_explode() } @@ -337,7 +340,7 @@ impl PySeries { }, _ => { let series: Series = - std::convert::TryFrom::try_from((name, arr)).map_err(PyPolarsErr::from)?; + Series::try_new(name.into(), arr).map_err(PyPolarsErr::from)?; Ok(series.into()) }, } diff --git a/py-polars/src/series/export.rs b/crates/polars-python/src/series/export.rs similarity index 99% rename from py-polars/src/series/export.rs rename to crates/polars-python/src/series/export.rs index 901050ad74be..886b6114427a 100644 --- a/py-polars/src/series/export.rs +++ b/crates/polars-python/src/series/export.rs @@ -2,9 +2,10 @@ use polars_core::prelude::*; use pyo3::prelude::*; use pyo3::types::{PyCapsule, PyList}; +use super::PySeries; +use crate::interop; use crate::interop::arrow::to_py::series_to_stream; use crate::prelude::*; -use crate::{interop, PySeries}; #[pymethods] impl PySeries { diff --git a/py-polars/src/series/mod.rs b/crates/polars-python/src/series/general.rs similarity index 64% rename from py-polars/src/series/mod.rs rename to crates/polars-python/src/series/general.rs index 153a2dc12df3..359f39df6291 100644 --- a/py-polars/src/series/mod.rs +++ b/crates/polars-python/src/series/general.rs @@ -1,72 +1,18 @@ -mod aggregation; -mod arithmetic; -mod buffers; -mod c_interface; -mod comparison; -mod construction; -mod export; -mod import; -mod numpy_ufunc; -mod scatter; - use std::io::Cursor; use polars_core::chunked_array::cast::CastOptions; use polars_core::series::IsSorted; use polars_core::utils::flatten::flatten_series; -use polars_core::{with_match_physical_numeric_polars_type, with_match_physical_numeric_type}; use pyo3::exceptions::{PyIndexError, PyRuntimeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::PyBytes; use pyo3::Python; +use super::PySeries; use crate::dataframe::PyDataFrame; use crate::error::PyPolarsErr; -use 
crate::map::series::{call_lambda_and_extract, ApplyLambda}; use crate::prelude::*; use crate::py_modules::POLARS; -use crate::{apply_method_all_arrow_series2, raise_err}; - -#[pyclass] -#[repr(transparent)] -#[derive(Clone)] -pub struct PySeries { - pub series: Series, -} - -impl From for PySeries { - fn from(series: Series) -> Self { - PySeries { series } - } -} - -impl PySeries { - pub(crate) fn new(series: Series) -> Self { - PySeries { series } - } -} - -pub(crate) trait ToSeries { - fn to_series(self) -> Vec; -} - -impl ToSeries for Vec { - fn to_series(self) -> Vec { - // SAFETY: repr is transparent. - unsafe { std::mem::transmute(self) } - } -} - -pub(crate) trait ToPySeries { - fn to_pyseries(self) -> Vec; -} - -impl ToPySeries for Vec { - fn to_pyseries(self) -> Vec { - // SAFETY: repr is transparent. - unsafe { std::mem::transmute(self) } - } -} #[pymethods] impl PySeries { @@ -162,7 +108,7 @@ impl PySeries { } } - fn rechunk(&mut self, in_place: bool) -> Option { + pub fn rechunk(&mut self, in_place: bool) -> Option { let series = self.series.rechunk(); if in_place { self.series = series; @@ -242,12 +188,12 @@ impl PySeries { self.series.chunk_lengths().collect() } - fn name(&self) -> &str { - self.series.name() + pub fn name(&self) -> &str { + self.series.name().as_str() } fn rename(&mut self, name: &str) { - self.series.rename(name); + self.series.rename(name.into()); } fn dtype(&self, py: Python) -> PyObject { @@ -351,7 +297,8 @@ impl PySeries { Ok(format!("{:?}", self.series)) } - fn len(&self) -> usize { + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { self.series.len() } @@ -366,238 +313,6 @@ impl PySeries { self.series.clone().into() } - #[pyo3(signature = (lambda, output_type, skip_nulls))] - fn apply_lambda( - &self, - lambda: &Bound, - output_type: Option>, - skip_nulls: bool, - ) -> PyResult { - let series = &self.series; - - if output_type.is_none() { - polars_warn!( - MapWithoutReturnDtypeWarning, - "Calling `map_elements` without specifying `return_dtype` can lead to unpredictable results. \ - Specify `return_dtype` to silence this warning.") - } - - if skip_nulls && (series.null_count() == series.len()) { - if let Some(output_type) = output_type { - return Ok(Series::full_null(series.name(), series.len(), &output_type.0).into()); - } - let msg = "The output type of the 'apply' function cannot be determined.\n\ - The function was never called because 'skip_nulls=True' and all values are null.\n\ - Consider setting 'skip_nulls=False' or setting the 'return_dtype'."; - raise_err!(msg, ComputeError) - } - - let output_type = output_type.map(|dt| dt.0); - - macro_rules! 
dispatch_apply { - ($self:expr, $method:ident, $($args:expr),*) => { - match $self.dtype() { - #[cfg(feature = "object")] - DataType::Object(_, _) => { - let ca = $self.0.unpack::>().unwrap(); - ca.$method($($args),*) - }, - _ => { - apply_method_all_arrow_series2!( - $self, - $method, - $($args),* - ) - } - - } - } - - } - - Python::with_gil(|py| { - if matches!( - self.series.dtype(), - DataType::Datetime(_, _) - | DataType::Date - | DataType::Duration(_) - | DataType::Categorical(_, _) - | DataType::Enum(_, _) - | DataType::Binary - | DataType::Array(_, _) - | DataType::Time - ) || !skip_nulls - { - let mut avs = Vec::with_capacity(self.series.len()); - let s = self.series.rechunk(); - let iter = s.iter().map(|av| match (skip_nulls, av) { - (true, AnyValue::Null) => AnyValue::Null, - (_, av) => { - let input = Wrap(av); - call_lambda_and_extract::<_, Wrap>(py, lambda, input) - .unwrap() - .0 - }, - }); - avs.extend(iter); - return Ok(Series::new(self.name(), &avs).into()); - } - - let out = match output_type { - Some(DataType::Int8) => { - let ca: Int8Chunked = dispatch_apply!( - series, - apply_lambda_with_primitive_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - Some(DataType::Int16) => { - let ca: Int16Chunked = dispatch_apply!( - series, - apply_lambda_with_primitive_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - Some(DataType::Int32) => { - let ca: Int32Chunked = dispatch_apply!( - series, - apply_lambda_with_primitive_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - Some(DataType::Int64) => { - let ca: Int64Chunked = dispatch_apply!( - series, - apply_lambda_with_primitive_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - Some(DataType::UInt8) => { - let ca: UInt8Chunked = dispatch_apply!( - series, - apply_lambda_with_primitive_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - Some(DataType::UInt16) => { - let ca: UInt16Chunked = dispatch_apply!( - series, - apply_lambda_with_primitive_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - Some(DataType::UInt32) => { - let ca: UInt32Chunked = dispatch_apply!( - series, - apply_lambda_with_primitive_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - Some(DataType::UInt64) => { - let ca: UInt64Chunked = dispatch_apply!( - series, - apply_lambda_with_primitive_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - Some(DataType::Float32) => { - let ca: Float32Chunked = dispatch_apply!( - series, - apply_lambda_with_primitive_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - Some(DataType::Float64) => { - let ca: Float64Chunked = dispatch_apply!( - series, - apply_lambda_with_primitive_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - Some(DataType::Boolean) => { - let ca: BooleanChunked = dispatch_apply!( - series, - apply_lambda_with_bool_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - Some(DataType::String) => { - let ca = dispatch_apply!( - series, - apply_lambda_with_string_out_type, - py, - lambda, - 0, - None - )?; - - ca.into_series() - }, - #[cfg(feature = "object")] - Some(DataType::Object(_, _)) => { - let ca = dispatch_apply!( - series, - apply_lambda_with_object_out_type, - py, - lambda, - 0, - None - )?; - ca.into_series() - }, - None => return dispatch_apply!(series, apply_lambda_unknown, py, lambda), - - _ => return dispatch_apply!(series, apply_lambda_unknown, py, lambda), - }; - - Ok(out.into()) - }) - 
} - fn zip_with(&self, mask: &PySeries, other: &PySeries) -> PyResult { let mask = mask.series.bool().map_err(PyPolarsErr::from)?; let s = self @@ -771,7 +486,7 @@ impl PySeries { ) -> PyResult { let out = self .series - .value_counts(sort, parallel, name, normalize) + .value_counts(sort, parallel, name.into(), normalize) .map_err(PyPolarsErr::from)?; Ok(out.into()) } @@ -861,6 +576,7 @@ impl_get!(get_duration, duration, i64); #[cfg(test)] mod test { use super::*; + use crate::series::ToSeries; #[test] fn transmute_to_series() { diff --git a/py-polars/src/series/import.rs b/crates/polars-python/src/series/import.rs similarity index 94% rename from py-polars/src/series/import.rs rename to crates/polars-python/src/series/import.rs index 4ea467ea3d58..3bc65f28aaf7 100644 --- a/py-polars/src/series/import.rs +++ b/crates/polars-python/src/series/import.rs @@ -4,11 +4,12 @@ use polars::export::arrow::ffi; use polars::export::arrow::ffi::{ ArrowArray, ArrowArrayStream, ArrowArrayStreamReader, ArrowSchema, }; +use polars::prelude::*; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::{PyCapsule, PyTuple, PyType}; -use super::*; +use super::PySeries; /// Validate PyCapsule has provided name fn validate_pycapsule_name(capsule: &Bound, expected_name: &str) -> PyResult<()> { @@ -69,7 +70,7 @@ pub(crate) fn import_array_pycapsules( let array_ptr = std::ptr::replace(array_capsule.pointer() as _, ArrowArray::empty()); let field = ffi::import_field_from_c(schema_ptr).unwrap(); - let array = ffi::import_array_from_c(array_ptr, field.data_type().clone()).unwrap(); + let array = ffi::import_array_from_c(array_ptr, field.dtype().clone()).unwrap(); (field, array) }; @@ -112,8 +113,8 @@ pub(crate) fn import_stream_pycapsule(capsule: &Bound) -> PyResult, + return_dtype: Option>, + skip_nulls: bool, + ) -> PyResult { + let series = &self.series; + + if return_dtype.is_none() { + polars_warn!( + MapWithoutReturnDtypeWarning, + "Calling `map_elements` without specifying `return_dtype` can lead to unpredictable results. \ + Specify `return_dtype` to silence this warning.") + } + + if skip_nulls && (series.null_count() == series.len()) { + if let Some(return_dtype) = return_dtype { + return Ok( + Series::full_null(series.name().clone(), series.len(), &return_dtype.0).into(), + ); + } + let msg = "The output type of the 'map_elements' function cannot be determined.\n\ + The function was never called because 'skip_nulls=True' and all values are null.\n\ + Consider setting 'skip_nulls=False' or setting the 'return_dtype'."; + raise_err!(msg, ComputeError) + } + + let return_dtype = return_dtype.map(|dt| dt.0); + + macro_rules! 
dispatch_apply { + ($self:expr, $method:ident, $($args:expr),*) => { + match $self.dtype() { + #[cfg(feature = "object")] + DataType::Object(_, _) => { + let ca = $self.0.unpack::>().unwrap(); + ca.$method($($args),*) + }, + _ => { + apply_method_all_arrow_series2!( + $self, + $method, + $($args),* + ) + } + + } + } + + } + + Python::with_gil(|py| { + if matches!( + self.series.dtype(), + DataType::Datetime(_, _) + | DataType::Date + | DataType::Duration(_) + | DataType::Categorical(_, _) + | DataType::Enum(_, _) + | DataType::Binary + | DataType::Array(_, _) + | DataType::Time + ) || !skip_nulls + { + let mut avs = Vec::with_capacity(self.series.len()); + let s = self.series.rechunk(); + let iter = s.iter().map(|av| match (skip_nulls, av) { + (true, AnyValue::Null) => AnyValue::Null, + (_, av) => { + let input = Wrap(av); + call_lambda_and_extract::<_, Wrap>(py, function, input) + .unwrap() + .0 + }, + }); + avs.extend(iter); + return Ok(Series::new(self.series.name().clone(), &avs).into()); + } + + let out = match return_dtype { + Some(DataType::Int8) => { + let ca: Int8Chunked = dispatch_apply!( + series, + apply_lambda_with_primitive_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + Some(DataType::Int16) => { + let ca: Int16Chunked = dispatch_apply!( + series, + apply_lambda_with_primitive_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + Some(DataType::Int32) => { + let ca: Int32Chunked = dispatch_apply!( + series, + apply_lambda_with_primitive_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + Some(DataType::Int64) => { + let ca: Int64Chunked = dispatch_apply!( + series, + apply_lambda_with_primitive_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + Some(DataType::UInt8) => { + let ca: UInt8Chunked = dispatch_apply!( + series, + apply_lambda_with_primitive_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + Some(DataType::UInt16) => { + let ca: UInt16Chunked = dispatch_apply!( + series, + apply_lambda_with_primitive_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + Some(DataType::UInt32) => { + let ca: UInt32Chunked = dispatch_apply!( + series, + apply_lambda_with_primitive_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + Some(DataType::UInt64) => { + let ca: UInt64Chunked = dispatch_apply!( + series, + apply_lambda_with_primitive_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + Some(DataType::Float32) => { + let ca: Float32Chunked = dispatch_apply!( + series, + apply_lambda_with_primitive_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + Some(DataType::Float64) => { + let ca: Float64Chunked = dispatch_apply!( + series, + apply_lambda_with_primitive_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + Some(DataType::Boolean) => { + let ca: BooleanChunked = dispatch_apply!( + series, + apply_lambda_with_bool_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + Some(DataType::String) => { + let ca = dispatch_apply!( + series, + apply_lambda_with_string_out_type, + py, + function, + 0, + None + )?; + + ca.into_series() + }, + Some(DataType::List(inner)) => { + // Make sure the function returns a Series of the correct data type. 
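// [editor's note — hedged elaboration, not part of the upstream diff]
// Conceptually, the closure built below behaves like the following Python
// wrapper (`pl` being the usual polars import alias, used here only for
// illustration):
//     def wrapped(*args):
//         return pl.Series("", function(*args), dtype=inner)
// i.e. every value the user function returns is coerced into a Series of the
// declared inner dtype before it is appended to the list builder.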
+ let function_owned = function.to_object(py); + let dtype_py = Wrap((*inner).clone()).to_object(py); + let function_wrapped = + PyCFunction::new_closure_bound(py, None, None, move |args, _kwargs| { + Python::with_gil(|py| { + let out = function_owned.call1(py, args)?; + SERIES.call1(py, ("", out, dtype_py.clone())) + }) + })? + .to_object(py); + + let ca = dispatch_apply!( + series, + apply_lambda_with_list_out_type, + py, + function_wrapped, + 0, + None, + inner.as_ref() + )?; + + ca.into_series() + }, + #[cfg(feature = "object")] + Some(DataType::Object(_, _)) => { + let ca = dispatch_apply!( + series, + apply_lambda_with_object_out_type, + py, + function, + 0, + None + )?; + ca.into_series() + }, + None => return dispatch_apply!(series, apply_lambda_unknown, py, function), + + _ => return dispatch_apply!(series, apply_lambda_unknown, py, function), + }; + + Ok(out.into()) + }) + } +} diff --git a/crates/polars-python/src/series/mod.rs b/crates/polars-python/src/series/mod.rs new file mode 100644 index 000000000000..1b4542b06c5a --- /dev/null +++ b/crates/polars-python/src/series/mod.rs @@ -0,0 +1,68 @@ +#[cfg(feature = "pymethods")] +mod aggregation; +#[cfg(feature = "pymethods")] +mod arithmetic; +#[cfg(feature = "pymethods")] +mod buffers; +#[cfg(feature = "pymethods")] +mod c_interface; +#[cfg(feature = "pymethods")] +mod comparison; +#[cfg(feature = "pymethods")] +mod construction; +#[cfg(feature = "pymethods")] +mod export; +#[cfg(feature = "pymethods")] +mod general; +#[cfg(feature = "pymethods")] +mod import; +#[cfg(feature = "pymethods")] +mod map; +#[cfg(feature = "pymethods")] +mod numpy_ufunc; +#[cfg(feature = "pymethods")] +mod scatter; + +use polars::prelude::Series; +use pyo3::pyclass; + +#[pyclass] +#[repr(transparent)] +#[derive(Clone)] +pub struct PySeries { + pub series: Series, +} + +impl From for PySeries { + fn from(series: Series) -> Self { + PySeries { series } + } +} + +impl PySeries { + pub(crate) fn new(series: Series) -> Self { + PySeries { series } + } +} + +pub(crate) trait ToSeries { + fn to_series(self) -> Vec; +} + +impl ToSeries for Vec { + fn to_series(self) -> Vec { + // SAFETY: repr is transparent. + unsafe { std::mem::transmute(self) } + } +} + +pub(crate) trait ToPySeries { + fn to_pyseries(self) -> Vec; +} + +impl ToPySeries for Vec { + fn to_pyseries(self) -> Vec { + // SAFETY: repr is transparent. + unsafe { std::mem::transmute(self) } + } +} diff --git a/py-polars/src/series/numpy_ufunc.rs b/crates/polars-python/src/series/numpy_ufunc.rs similarity index 95% rename from py-polars/src/series/numpy_ufunc.rs rename to crates/polars-python/src/series/numpy_ufunc.rs index 94aa42ffa18c..10d765c3fc25 100644 --- a/py-polars/src/series/numpy_ufunc.rs +++ b/crates/polars-python/src/series/numpy_ufunc.rs @@ -9,7 +9,7 @@ use polars_core::utils::arrow::types::NativeType; use pyo3::prelude::*; use pyo3::types::{PyNone, PyTuple}; -use crate::series::PySeries; +use super::PySeries; /// Create an empty numpy array arrows 64 byte alignment /// @@ -109,8 +109,11 @@ macro_rules! 
impl_ufuncs { assert!(get_refcnt(&out_array) <= 3); let validity = self.series.chunks()[0].validity().cloned(); - let ca = - ChunkedArray::<$type>::from_vec_validity(self.name(), av, validity); + let ca = ChunkedArray::<$type>::from_vec_validity( + self.series.name().clone(), + av, + validity, + ); PySeries::new(ca.into_series()) }, Err(e) => { diff --git a/py-polars/src/series/scatter.rs b/crates/polars-python/src/series/scatter.rs similarity index 99% rename from py-polars/src/series/scatter.rs rename to crates/polars-python/src/series/scatter.rs index a1ac9ac4c979..97df60ef205b 100644 --- a/py-polars/src/series/scatter.rs +++ b/crates/polars-python/src/series/scatter.rs @@ -2,8 +2,8 @@ use polars::export::arrow::array::Array; use polars::prelude::*; use pyo3::prelude::*; +use super::PySeries; use crate::error::PyPolarsErr; -use crate::PySeries; #[pymethods] impl PySeries { diff --git a/py-polars/src/sql.rs b/crates/polars-python/src/sql.rs similarity index 100% rename from py-polars/src/sql.rs rename to crates/polars-python/src/sql.rs diff --git a/py-polars/src/utils.rs b/crates/polars-python/src/utils.rs similarity index 100% rename from py-polars/src/utils.rs rename to crates/polars-python/src/utils.rs diff --git a/crates/polars-row/src/decode.rs b/crates/polars-row/src/decode.rs index 180cf2ad00e8..858ce3f55fcf 100644 --- a/crates/polars-row/src/decode.rs +++ b/crates/polars-row/src/decode.rs @@ -11,13 +11,13 @@ use crate::variable::{decode_binary, decode_binview}; pub unsafe fn decode_rows_from_binary<'a>( arr: &'a BinaryArray, fields: &[EncodingField], - data_types: &[ArrowDataType], + dtypes: &[ArrowDataType], rows: &mut Vec<&'a [u8]>, ) -> Vec { assert_eq!(arr.null_count(), 0); rows.clear(); rows.extend(arr.values_iter()); - decode_rows(rows, fields, data_types) + decode_rows(rows, fields, dtypes) } /// Decode `rows` into a arrow format @@ -28,18 +28,18 @@ pub unsafe fn decode_rows( // the rows will be updated while the data is decoded rows: &mut [&[u8]], fields: &[EncodingField], - data_types: &[ArrowDataType], + dtypes: &[ArrowDataType], ) -> Vec { - assert_eq!(fields.len(), data_types.len()); - data_types + assert_eq!(fields.len(), dtypes.len()); + dtypes .iter() .zip(fields) - .map(|(data_type, field)| decode(rows, field, data_type)) + .map(|(dtype, field)| decode(rows, field, dtype)) .collect() } -unsafe fn decode(rows: &mut [&[u8]], field: &EncodingField, data_type: &ArrowDataType) -> ArrayRef { - match data_type { +unsafe fn decode(rows: &mut [&[u8]], field: &EncodingField, dtype: &ArrowDataType) -> ArrayRef { + match dtype { ArrowDataType::Null => NullArray::new(ArrowDataType::Null, rows.len()).to_boxed(), ArrowDataType::Boolean => decode_bool(rows, field).to_boxed(), ArrowDataType::BinaryView | ArrowDataType::LargeBinary => { @@ -62,9 +62,9 @@ unsafe fn decode(rows: &mut [&[u8]], field: &EncodingField, data_type: &ArrowDat ArrowDataType::Struct(fields) => { let values = fields .iter() - .map(|struct_fld| decode(rows, field, struct_fld.data_type())) + .map(|struct_fld| decode(rows, field, struct_fld.dtype())) .collect(); - StructArray::new(data_type.clone(), values, None).to_boxed() + StructArray::new(dtype.clone(), values, None).to_boxed() }, dt => { with_match_arrow_primitive_type!(dt, |$T| { diff --git a/crates/polars-row/src/encode.rs b/crates/polars-row/src/encode.rs index 00e888c0e9b0..a415e2fe1915 100644 --- a/crates/polars-row/src/encode.rs +++ b/crates/polars-row/src/encode.rs @@ -91,10 +91,10 @@ impl Encoder { } } - fn data_type(&self) -> &ArrowDataType { + fn 
dtype(&self) -> &ArrowDataType { match self { - Encoder::List { original, .. } => original.data_type(), - Encoder::Leaf(arr) => arr.data_type(), + Encoder::List { original, .. } => original.dtype(), + Encoder::Leaf(arr) => arr.dtype(), } } @@ -102,7 +102,7 @@ impl Encoder { match self { Encoder::Leaf(arr) => { matches!( - arr.data_type(), + arr.dtype(), ArrowDataType::BinaryView | ArrowDataType::Dictionary(_, _, _) | ArrowDataType::LargeBinary @@ -115,7 +115,7 @@ impl Encoder { fn get_encoders(arr: &dyn Array, encoders: &mut Vec, field: &EncodingField) -> usize { let mut added = 0; - match arr.data_type() { + match arr.dtype() { ArrowDataType::Struct(_) => { let arr = arr.as_any().downcast_ref::().unwrap(); for value_arr in arr.values() { @@ -164,7 +164,7 @@ pub fn convert_columns_amortized<'a, I: IntoIterator>( assert_eq!(fields.size_hint().0, columns.len()); if columns.iter().any(|arr| { matches!( - arr.data_type(), + arr.dtype(), ArrowDataType::Struct(_) | ArrowDataType::Utf8View | ArrowDataType::LargeList(_) ) }) { @@ -233,7 +233,7 @@ unsafe fn encode_array(encoder: &Encoder, field: &EncodingField, out: &mut RowsE crate::variable::encode_iter(iter, out, &EncodingField::new_unsorted()) }, Encoder::Leaf(array) => { - match array.data_type() { + match array.dtype() { ArrowDataType::Boolean => { let array = array.as_any().downcast_ref::().unwrap(); crate::fixed::encode_iter(array.into_iter(), out, field); @@ -271,9 +271,9 @@ unsafe fn encode_array(encoder: &Encoder, field: &EncodingField, out: &mut RowsE } } -pub fn encoded_size(data_type: &ArrowDataType) -> usize { +pub fn encoded_size(dtype: &ArrowDataType) -> usize { use ArrowDataType::*; - match data_type { + match dtype { UInt8 => u8::ENCODED_LEN, UInt16 => u16::ENCODED_LEN, UInt32 => u32::ENCODED_LEN, @@ -310,7 +310,7 @@ fn allocate_rows_buf( if enc.is_variable() { 0 } else { - encoded_size(enc.data_type()) + encoded_size(enc.dtype()) } }) .sum(); @@ -338,9 +338,9 @@ fn allocate_rows_buf( // encode the rows instead of only setting the length. // This needs a bit refactoring, might require allocation and encoding to be in // the same function. 
- if let ArrowDataType::LargeList(inner) = original.data_type() { + if let ArrowDataType::LargeList(inner) = original.dtype() { assert!( - !matches!(inner.data_type, ArrowDataType::LargeList(_)), + !matches!(inner.dtype, ArrowDataType::LargeList(_)), "should not be nested" ) } @@ -390,7 +390,7 @@ fn allocate_rows_buf( processed_count += 1; }, Encoder::Leaf(array) => { - match array.data_type() { + match array.dtype() { ArrowDataType::BinaryView => { let array = array.as_any().downcast_ref::().unwrap(); if processed_count == 0 { @@ -483,10 +483,7 @@ fn allocate_rows_buf( values.reserve(current_offset); current_offset } else { - let row_size: usize = columns - .iter() - .map(|arr| encoded_size(arr.data_type())) - .sum(); + let row_size: usize = columns.iter().map(|arr| encoded_size(arr.dtype())).sum(); let n_bytes = num_rows * row_size; values.clear(); values.reserve(n_bytes); @@ -625,7 +622,7 @@ mod test { let values = Utf8ViewArray::from_slice_values([ "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten", ]); - let dtype = LargeListArray::default_datatype(values.data_type().clone()); + let dtype = LargeListArray::default_datatype(values.dtype().clone()); let array = LargeListArray::new( dtype, Offsets::::try_from(vec![0i64, 1, 4, 7, 7, 9, 10]) diff --git a/crates/polars-row/src/fixed.rs b/crates/polars-row/src/fixed.rs index f9bdc4394b08..7932420d1577 100644 --- a/crates/polars-row/src/fixed.rs +++ b/crates/polars-row/src/fixed.rs @@ -219,7 +219,7 @@ pub(super) unsafe fn decode_primitive( where T::Encoded: FromSlice, { - let data_type: ArrowDataType = T::PRIMITIVE.into(); + let dtype: ArrowDataType = T::PRIMITIVE.into(); let mut has_nulls = false; let null_sentinel = get_null_sentinel(field); @@ -252,7 +252,7 @@ where let increment_len = T::ENCODED_LEN; increment_row_counter(rows, increment_len); - PrimitiveArray::new(data_type, values.into(), validity) + PrimitiveArray::new(dtype, values.into(), validity) } pub(super) unsafe fn decode_bool(rows: &mut [&[u8]], field: &EncodingField) -> BooleanArray { diff --git a/crates/polars-schema/Cargo.toml b/crates/polars-schema/Cargo.toml new file mode 100644 index 000000000000..f928d9f36925 --- /dev/null +++ b/crates/polars-schema/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "polars-schema" +version = { workspace = true } +authors = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +repository = { workspace = true } +description = "Private crate for schema utilities for the Polars DataFrame library" + +[dependencies] +indexmap = { workspace = true } +polars-error = { workspace = true } +polars-utils = { workspace = true } +serde = { workspace = true, optional = true } + +[build-dependencies] +version_check = { workspace = true } + +[features] +nightly = [] +serde = ["dep:serde", "serde/derive"] diff --git a/crates/polars-schema/LICENSE b/crates/polars-schema/LICENSE new file mode 120000 index 000000000000..30cff7403da0 --- /dev/null +++ b/crates/polars-schema/LICENSE @@ -0,0 +1 @@ +../../LICENSE \ No newline at end of file diff --git a/crates/polars-schema/README.md b/crates/polars-schema/README.md new file mode 100644 index 000000000000..6d68ee41675a --- /dev/null +++ b/crates/polars-schema/README.md @@ -0,0 +1,5 @@ +# polars-schema + +`polars-schema` is an **internal sub-crate** of the [Polars](https://crates.io/crates/polars) library, supplying private schema utility functions. + +**Important Note**: This crate is **not intended for external usage**. 
Please refer to the main [Polars crate](https://crates.io/crates/polars) for intended usage. diff --git a/crates/polars-schema/build.rs b/crates/polars-schema/build.rs new file mode 100644 index 000000000000..3e4ab64620ac --- /dev/null +++ b/crates/polars-schema/build.rs @@ -0,0 +1,7 @@ +fn main() { + println!("cargo:rerun-if-changed=build.rs"); + let channel = version_check::Channel::read().unwrap(); + if channel.is_nightly() { + println!("cargo:rustc-cfg=feature=\"nightly\""); + } +} diff --git a/crates/polars-schema/src/lib.rs b/crates/polars-schema/src/lib.rs new file mode 100644 index 000000000000..902d6a57d6f9 --- /dev/null +++ b/crates/polars-schema/src/lib.rs @@ -0,0 +1,2 @@ +pub mod schema; +pub use schema::Schema; diff --git a/crates/polars-schema/src/schema.rs b/crates/polars-schema/src/schema.rs new file mode 100644 index 000000000000..3f03bdffde24 --- /dev/null +++ b/crates/polars-schema/src/schema.rs @@ -0,0 +1,456 @@ +use core::fmt::{Debug, Formatter}; +use core::hash::{Hash, Hasher}; + +use indexmap::map::MutableKeys; +use polars_error::{polars_bail, polars_ensure, polars_err, PolarsResult}; +use polars_utils::aliases::{InitHashMaps, PlIndexMap}; +use polars_utils::pl_str::PlSmallStr; + +#[derive(Clone, Default)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct Schema { + fields: PlIndexMap, +} + +impl Eq for Schema {} + +impl Schema { + pub fn with_capacity(capacity: usize) -> Self { + let fields = PlIndexMap::with_capacity(capacity); + Self { fields } + } + + /// Reserve `additional` memory spaces in the schema. + pub fn reserve(&mut self, additional: usize) { + self.fields.reserve(additional); + } + + /// The number of fields in the schema. + #[inline] + pub fn len(&self) -> usize { + self.fields.len() + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.fields.is_empty() + } + + /// Rename field `old` to `new`, and return the (owned) old name. + /// + /// If `old` is not present in the schema, the schema is not modified and `None` is returned. Otherwise the schema + /// is updated and `Some(old_name)` is returned. + pub fn rename(&mut self, old: &str, new: PlSmallStr) -> Option { + // Remove `old`, get the corresponding index and dtype, and move the last item in the map to that position + let (old_index, old_name, dtype) = self.fields.swap_remove_full(old)?; + // Insert the same dtype under the new name at the end of the map and store that index + let (new_index, _) = self.fields.insert_full(new, dtype); + // Swap the two indices to move the originally last element back to the end and to move the new element back to + // its original position + self.fields.swap_indices(old_index, new_index); + + Some(old_name) + } + + pub fn insert(&mut self, key: PlSmallStr, value: D) -> Option { + self.fields.insert(key, value) + } + + /// Insert a field with `name` and `dtype` at the given `index` into this schema. + /// + /// If a field named `name` already exists, it is updated with the new dtype. Regardless, the field named `name` is + /// always moved to the given index. Valid indices range from `0` (front of the schema) to `self.len()` (after the + /// end of the schema). + /// + /// For a non-mutating version that clones the schema, see [`new_inserting_at_index`][Self::new_inserting_at_index]. + /// + /// Runtime: **O(n)** where `n` is the number of fields in the schema. 
+ /// + /// Returns: + /// - If index is out of bounds, `Err(PolarsError)` + /// - Else if `name` was already in the schema, `Ok(Some(old_dtype))` + /// - Else `Ok(None)` + pub fn insert_at_index( + &mut self, + mut index: usize, + name: PlSmallStr, + dtype: D, + ) -> PolarsResult> { + polars_ensure!( + index <= self.len(), + OutOfBounds: + "index {} is out of bounds for schema with length {} (the max index allowed is self.len())", + index, + self.len() + ); + + let (old_index, old_dtype) = self.fields.insert_full(name, dtype); + + // If we're moving an existing field, one-past-the-end will actually be out of bounds. Also, self.len() won't + // have changed after inserting, so `index == self.len()` is the same as it was before inserting. + if old_dtype.is_some() && index == self.len() { + index -= 1; + } + self.fields.move_index(old_index, index); + Ok(old_dtype) + } + + /// Get a reference to the dtype of the field named `name`, or `None` if the field doesn't exist. + pub fn get(&self, name: &str) -> Option<&D> { + self.fields.get(name) + } + + /// Get a reference to the dtype of the field named `name`, or `Err(PolarsErr)` if the field doesn't exist. + pub fn try_get(&self, name: &str) -> PolarsResult<&D> { + self.get(name) + .ok_or_else(|| polars_err!(SchemaFieldNotFound: "{}", name)) + } + + /// Get a mutable reference to the dtype of the field named `name`, or `Err(PolarsErr)` if the field doesn't exist. + pub fn try_get_mut(&mut self, name: &str) -> PolarsResult<&mut D> { + self.fields + .get_mut(name) + .ok_or_else(|| polars_err!(SchemaFieldNotFound: "{}", name)) + } + + /// Return all data about the field named `name`: its index in the schema, its name, and its dtype. + /// + /// Returns `Some((index, &name, &dtype))` if the field exists, `None` if it doesn't. + pub fn get_full(&self, name: &str) -> Option<(usize, &PlSmallStr, &D)> { + self.fields.get_full(name) + } + + /// Return all data about the field named `name`: its index in the schema, its name, and its dtype. + /// + /// Returns `Ok((index, &name, &dtype))` if the field exists, `Err(PolarsErr)` if it doesn't. + pub fn try_get_full(&self, name: &str) -> PolarsResult<(usize, &PlSmallStr, &D)> { + self.fields + .get_full(name) + .ok_or_else(|| polars_err!(SchemaFieldNotFound: "{}", name)) + } + + /// Get references to the name and dtype of the field at `index`. + /// + /// If `index` is inbounds, returns `Some((&name, &dtype))`, else `None`. See + /// [`get_at_index_mut`][Self::get_at_index_mut] for a mutable version. + pub fn get_at_index(&self, index: usize) -> Option<(&PlSmallStr, &D)> { + self.fields.get_index(index) + } + + pub fn try_get_at_index(&self, index: usize) -> PolarsResult<(&PlSmallStr, &D)> { + self.fields.get_index(index).ok_or_else(|| polars_err!(ComputeError: "index {index} out of bounds with 'schema' of len: {}", self.len())) + } + + /// Get mutable references to the name and dtype of the field at `index`. + /// + /// If `index` is inbounds, returns `Some((&mut name, &mut dtype))`, else `None`. See + /// [`get_at_index`][Self::get_at_index] for an immutable version. + pub fn get_at_index_mut(&mut self, index: usize) -> Option<(&mut PlSmallStr, &mut D)> { + self.fields.get_index_mut2(index) + } + + /// Swap-remove a field by name and, if the field existed, return its dtype. + /// + /// If the field does not exist, the schema is not modified and `None` is returned. 
+ /// + /// This method does a `swap_remove`, which is O(1) but **changes the order of the schema**: the field named `name` + /// is replaced by the last field, which takes its position. For a slower, but order-preserving, method, use + /// [`shift_remove`][Self::shift_remove]. + pub fn remove(&mut self, name: &str) -> Option { + self.fields.swap_remove(name) + } + + /// Remove a field by name, preserving order, and, if the field existed, return its dtype. + /// + /// If the field does not exist, the schema is not modified and `None` is returned. + /// + /// This method does a `shift_remove`, which preserves the order of the fields in the schema but **is O(n)**. For a + /// faster, but not order-preserving, method, use [`remove`][Self::remove]. + pub fn shift_remove(&mut self, name: &str) -> Option { + self.fields.shift_remove(name) + } + + /// Remove a field by name, preserving order, and, if the field existed, return its dtype. + /// + /// If the field does not exist, the schema is not modified and `None` is returned. + /// + /// This method does a `shift_remove`, which preserves the order of the fields in the schema but **is O(n)**. For a + /// faster, but not order-preserving, method, use [`remove`][Self::remove]. + pub fn shift_remove_index(&mut self, index: usize) -> Option<(PlSmallStr, D)> { + self.fields.shift_remove_index(index) + } + + /// Whether the schema contains a field named `name`. + pub fn contains(&self, name: &str) -> bool { + self.get(name).is_some() + } + + /// Change the field named `name` to the given `dtype` and return the previous dtype. + /// + /// If `name` doesn't already exist in the schema, the schema is not modified and `None` is returned. Otherwise + /// returns `Some(old_dtype)`. + /// + /// This method only ever modifies an existing field and never adds a new field to the schema. To add a new field, + /// use [`with_column`][Self::with_column] or [`insert_at_index`][Self::insert_at_index]. + pub fn set_dtype(&mut self, name: &str, dtype: D) -> Option { + let old_dtype = self.fields.get_mut(name)?; + Some(std::mem::replace(old_dtype, dtype)) + } + + /// Change the field at the given index to the given `dtype` and return the previous dtype. + /// + /// If the index is out of bounds, the schema is not modified and `None` is returned. Otherwise returns + /// `Some(old_dtype)`. + /// + /// This method only ever modifies an existing index and never adds a new field to the schema. To add a new field, + /// use [`with_column`][Self::with_column] or [`insert_at_index`][Self::insert_at_index]. + pub fn set_dtype_at_index(&mut self, index: usize, dtype: D) -> Option { + let (_, old_dtype) = self.fields.get_index_mut(index)?; + Some(std::mem::replace(old_dtype, dtype)) + } + + /// Insert a new column in the [`Schema`]. + /// + /// If an equivalent name already exists in the schema: the name remains and + /// retains in its place in the order, its corresponding value is updated + /// with [`D`] and the older dtype is returned inside `Some(_)`. + /// + /// If no equivalent key existed in the map: the new name-dtype pair is + /// inserted, last in order, and `None` is returned. + /// + /// To enforce the index of the resulting field, use [`insert_at_index`][Self::insert_at_index]. + /// + /// Computes in **O(1)** time (amortized average). + pub fn with_column(&mut self, name: PlSmallStr, dtype: D) -> Option { + self.fields.insert(name, dtype) + } + + /// Merge `other` into `self`. 
+ /// + /// Merging logic: + /// - Fields that occur in `self` but not `other` are unmodified + /// - Fields that occur in `other` but not `self` are appended, in order, to the end of `self` + /// - Fields that occur in both `self` and `other` are updated with the dtype from `other`, but keep their original + /// index + pub fn merge(&mut self, other: Self) { + self.fields.extend(other.fields) + } + + /// Iterates over the `(&name, &dtype)` pairs in this schema. + /// + /// For an owned version, use [`iter_fields`][Self::iter_fields], which clones the data to iterate owned `Field`s + pub fn iter(&self) -> impl ExactSizeIterator + '_ { + self.fields.iter() + } + + pub fn iter_mut(&mut self) -> impl ExactSizeIterator + '_ { + self.fields.iter_mut() + } + + /// Iterates over references to the names in this schema. + pub fn iter_names(&self) -> impl '_ + ExactSizeIterator { + self.fields.iter().map(|(name, _dtype)| name) + } + + pub fn iter_names_cloned(&self) -> impl '_ + ExactSizeIterator { + self.iter_names().cloned() + } + + /// Iterates over references to the dtypes in this schema. + pub fn iter_values(&self) -> impl '_ + ExactSizeIterator { + self.fields.iter().map(|(_name, dtype)| dtype) + } + + pub fn into_iter_values(self) -> impl ExactSizeIterator { + self.fields.into_values() + } + + /// Iterates over mut references to the dtypes in this schema. + pub fn iter_values_mut(&mut self) -> impl '_ + ExactSizeIterator { + self.fields.iter_mut().map(|(_name, dtype)| dtype) + } + + pub fn index_of(&self, name: &str) -> Option { + self.fields.get_index_of(name) + } + + pub fn try_index_of(&self, name: &str) -> PolarsResult { + let Some(i) = self.fields.get_index_of(name) else { + polars_bail!( + ColumnNotFound: + "unable to find column {:?}; valid columns: {:?}", + name, self.iter_names().collect::>(), + ) + }; + + Ok(i) + } +} + +impl Schema +where + D: Clone + Default, +{ + /// Create a new schema from this one, inserting a field with `name` and `dtype` at the given `index`. + /// + /// If a field named `name` already exists, it is updated with the new dtype. Regardless, the field named `name` is + /// always moved to the given index. Valid indices range from `0` (front of the schema) to `self.len()` (after the + /// end of the schema). + /// + /// For a mutating version that doesn't clone, see [`insert_at_index`][Self::insert_at_index]. + /// + /// Runtime: **O(m * n)** where `m` is the (average) length of the field names and `n` is the number of fields in + /// the schema. This method clones every field in the schema. + /// + /// Returns: `Ok(new_schema)` if `index <= self.len()`, else `Err(PolarsError)` + pub fn new_inserting_at_index( + &self, + index: usize, + name: PlSmallStr, + field: D, + ) -> PolarsResult { + polars_ensure!( + index <= self.len(), + OutOfBounds: + "index {} is out of bounds for schema with length {} (the max index allowed is self.len())", + index, + self.len() + ); + + let mut new = Self::default(); + let mut iter = self.fields.iter().filter_map(|(fld_name, dtype)| { + (fld_name != &name).then_some((fld_name.clone(), dtype.clone())) + }); + new.fields.extend(iter.by_ref().take(index)); + new.fields.insert(name.clone(), field); + new.fields.extend(iter); + Ok(new) + } + + /// Merge borrowed `other` into `self`. 
+ /// + /// Merging logic: + /// - Fields that occur in `self` but not `other` are unmodified + /// - Fields that occur in `other` but not `self` are appended, in order, to the end of `self` + /// - Fields that occur in both `self` and `other` are updated with the dtype from `other`, but keep their original + /// index + pub fn merge_from_ref(&mut self, other: &Self) { + self.fields.extend( + other + .iter() + .map(|(column, field)| (column.clone(), field.clone())), + ) + } + + /// Generates another schema with just the specified columns selected from this one. + pub fn try_project(&self, columns: I) -> PolarsResult + where + I: IntoIterator, + I::Item: AsRef, + { + let schema = columns + .into_iter() + .map(|c| { + let name = c.as_ref(); + let (_, name, dtype) = self + .fields + .get_full(name) + .ok_or_else(|| polars_err!(col_not_found = name))?; + PolarsResult::Ok((name.clone(), dtype.clone())) + }) + .collect::>>()?; + Ok(Self::from(schema)) + } + + pub fn try_project_indices(&self, indices: &[usize]) -> PolarsResult { + let fields = indices + .iter() + .map(|&i| { + let Some((k, v)) = self.fields.get_index(i) else { + polars_bail!( + SchemaFieldNotFound: + "projection index {} is out of bounds for schema of length {}", + i, self.fields.len() + ); + }; + + Ok((k.clone(), v.clone())) + }) + .collect::>>()?; + + Ok(Self { fields }) + } + + /// Returns a new [`Schema`] with a subset of all fields whose `predicate` + /// evaluates to true. + pub fn filter bool>(self, predicate: F) -> Self { + let fields = self + .fields + .into_iter() + .enumerate() + .filter_map(|(index, (name, d))| { + if (predicate)(index, &d) { + Some((name, d)) + } else { + None + } + }) + .collect(); + + Self { fields } + } +} + +impl Debug for Schema { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + writeln!(f, "Schema:")?; + for (name, field) in self.fields.iter() { + writeln!(f, "name: {name}, field: {field:?}")?; + } + Ok(()) + } +} + +impl Hash for Schema { + fn hash(&self, state: &mut H) { + self.fields.iter().for_each(|v| v.hash(state)) + } +} + +// Schemas will only compare equal if they have the same fields in the same order. We can't use `self.inner == +// other.inner` because [`IndexMap`] ignores order when checking equality, but we don't want to ignore it. 
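The order-sensitive equality described in the comment above can be illustrated with a short sketch (illustrative only, not part of the patch; it uses `String` as a stand-in for the dtype parameter `D` and assumes `PlSmallStr: From<&str>`, which other hunks in this diff rely on):

use polars_error::PolarsResult;
use polars_schema::Schema;
use polars_utils::pl_str::PlSmallStr;

fn schema_order_demo() -> PolarsResult<()> {
    // Build a schema with `String` values standing in for real dtypes.
    let mut a: Schema<String> = Schema::with_capacity(2);
    a.with_column("x".into(), "Int64".to_string());
    a.with_column("y".into(), "String".to_string());

    // `insert_at_index` is O(n): "z" becomes the first field, the others shift right.
    a.insert_at_index(0, "z".into(), "Boolean".to_string())?;
    assert_eq!(a.index_of("z"), Some(0));

    // `rename` keeps the field in its position and returns the old name.
    assert!(a.rename("x", "x2".into()).is_some());

    // The same fields in a different order compare unequal.
    let b: Schema<String> = [("y", "String"), ("x2", "Int64"), ("z", "Boolean")]
        .into_iter()
        .map(|(name, dtype)| (PlSmallStr::from(name), dtype.to_string()))
        .collect();
    assert!(a != b);
    Ok(())
}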
+impl PartialEq for Schema { + fn eq(&self, other: &Self) -> bool { + self.fields.len() == other.fields.len() + && self + .fields + .iter() + .zip(other.fields.iter()) + .all(|(a, b)| a == b) + } +} + +impl From> for Schema { + fn from(fields: PlIndexMap) -> Self { + Self { fields } + } +} + +impl FromIterator for Schema +where + F: Into<(PlSmallStr, D)>, +{ + fn from_iter>(iter: I) -> Self { + let fields = PlIndexMap::from_iter(iter.into_iter().map(|x| x.into())); + Self { fields } + } +} + +impl IntoIterator for Schema { + type IntoIter = as IntoIterator>::IntoIter; + type Item = (PlSmallStr, D); + + fn into_iter(self) -> Self::IntoIter { + self.fields.into_iter() + } +} diff --git a/crates/polars-sql/Cargo.toml b/crates/polars-sql/Cargo.toml index 0c8f883daf50..c959694214ce 100644 --- a/crates/polars-sql/Cargo.toml +++ b/crates/polars-sql/Cargo.toml @@ -16,6 +16,7 @@ polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_ polars-ops = { workspace = true } polars-plan = { workspace = true } polars-time = { workspace = true } +polars-utils = { workspace = true } hex = { workspace = true } once_cell = { workspace = true } @@ -37,7 +38,7 @@ csv = ["polars-lazy/csv"] diagonal_concat = ["polars-lazy/diagonal_concat"] dtype-decimal = ["polars-lazy/dtype-decimal"] ipc = ["polars-lazy/ipc"] -json = ["polars-lazy/json", "polars-plan/extract_jsonpath"] +json = ["polars-lazy/json", "polars-plan/json", "polars-plan/extract_jsonpath"] list_eval = ["polars-lazy/list_eval"] parquet = ["polars-lazy/parquet"] semi_anti_join = ["polars-lazy/semi_anti_join"] diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index ab1b9a53997c..23ffb25070fa 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -8,6 +8,7 @@ use polars_lazy::prelude::*; use polars_ops::frame::JoinCoalesce; use polars_plan::dsl::function_expr::StructFunction; use polars_plan::prelude::*; +use polars_utils::format_pl_smallstr; use sqlparser::ast::{ BinaryOperator, CreateTable, Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, Ident, JoinConstraint, JoinOperator, ObjectName, ObjectType, Offset, OrderBy, @@ -32,10 +33,10 @@ pub struct TableInfo { } struct SelectModifiers { - exclude: PlHashSet, // SELECT * EXCLUDE - ilike: Option, // SELECT * ILIKE - rename: PlHashMap, // SELECT * RENAME - replace: Vec, // SELECT * REPLACE + exclude: PlHashSet, // SELECT * EXCLUDE + ilike: Option, // SELECT * ILIKE + rename: PlHashMap, // SELECT * RENAME + replace: Vec, // SELECT * REPLACE } impl SelectModifiers { fn matches_ilike(&self, s: &str) -> bool { @@ -47,7 +48,7 @@ impl SelectModifiers { fn renamed_cols(&self) -> Vec { self.rename .iter() - .map(|(before, after)| col(before).alias(after)) + .map(|(before, after)| col(before.clone()).alias(after.clone())) .collect() } } @@ -380,12 +381,12 @@ impl SQLContext { .join_nulls(true); let lf_schema = self.get_frame_schema(&mut lf)?; - let lf_cols: Vec<_> = lf_schema.iter_names().map(|nm| col(nm)).collect(); + let lf_cols: Vec<_> = lf_schema.iter_names().map(|nm| col(nm.clone())).collect(); let joined_tbl = match quantifier { - SetQuantifier::ByName | SetQuantifier::AllByName => join.on(lf_cols).finish(), + SetQuantifier::ByName => join.on(lf_cols).finish(), SetQuantifier::Distinct | SetQuantifier::None => { let rf_schema = self.get_frame_schema(&mut rf)?; - let rf_cols: Vec<_> = rf_schema.iter_names().map(|nm| col(nm)).collect(); + let rf_cols: Vec<_> = rf_schema.iter_names().map(|nm| 
col(nm.clone())).collect(); if lf_cols.len() != rf_cols.len() { polars_bail!(SQLInterface: "{} requires equal number of columns in each table (use '{} BY NAME' to combine mismatched tables)", op_name, op_name) } @@ -470,7 +471,7 @@ impl SQLContext { let plan = plan .split('\n') .collect::() - .with_name("Logical Plan"); + .with_name(PlSmallStr::from_static("Logical Plan")); let df = DataFrame::new(vec![plan])?; Ok(df.lazy()) }, @@ -480,7 +481,7 @@ impl SQLContext { // SHOW TABLES fn execute_show_tables(&mut self, _: &Statement) -> PolarsResult { - let tables = Series::new("name", self.get_tables()); + let tables = Series::new("name".into(), self.get_tables()); let df = DataFrame::new(vec![tables])?; Ok(df.lazy()) } @@ -592,7 +593,9 @@ impl SQLContext { }, )? }, - JoinOperator::CrossJoin => lf.cross_join(rf, Some(format!(":{}", r_name))), + JoinOperator::CrossJoin => { + lf.cross_join(rf, Some(format_pl_smallstr!(":{}", r_name))) + }, join_type => { polars_bail!(SQLInterface: "join type '{:?}' not currently supported", join_type) }, @@ -687,7 +690,7 @@ impl SQLContext { if matches!(&**e, Expr::Agg(_) | Expr::Len | Expr::Literal(_)) => {}, Expr::Alias(e, _) if matches!(&**e, Expr::Column(_)) => { if let Expr::Column(name) = &**e { - group_by_keys.push(col(name)); + group_by_keys.push(col(name.clone())); } }, _ => { @@ -773,7 +776,7 @@ impl SQLContext { .map(|e| { let expr = parse_sql_expr(e, self, schema.as_deref())?; if let Expr::Column(name) = expr { - Ok(name.to_string()) + Ok(name.clone()) } else { Err(polars_err!(SQLSyntax:"DISTINCT ON only supports column names")) } @@ -782,7 +785,7 @@ impl SQLContext { // DISTINCT ON has to apply the ORDER BY before the operation. lf = self.process_order_by(lf, &query.order_by, None)?; - return Ok(lf.unique_stable(Some(cols), UniqueKeepStrategy::First)); + return Ok(lf.unique_stable(Some(cols.clone()), UniqueKeepStrategy::First)); }, None => lf, }; @@ -804,7 +807,7 @@ impl SQLContext { }, SelectItem::ExprWithAlias { expr, alias } => { let expr = parse_sql_expr(expr, self, Some(schema))?; - Ok(vec![expr.alias(&alias.value)]) + Ok(vec![expr.alias(PlSmallStr::from_str(alias.value.as_str()))]) }, SelectItem::QualifiedWildcard(obj_name, wildcard_options) => self .process_qualified_wildcard( @@ -816,7 +819,7 @@ impl SQLContext { SelectItem::Wildcard(wildcard_options) => { let cols = schema .iter_names() - .map(|name| col(name)) + .map(|name| col(name.clone())) .collect::>(); self.process_wildcard_additional_options( @@ -844,8 +847,28 @@ impl SQLContext { expr: &Option, ) -> PolarsResult { if let Some(expr) = expr { - let schema = Some(self.get_frame_schema(&mut lf)?); - let mut filter_expression = parse_sql_expr(expr, self, schema.as_deref())?; + let schema = self.get_frame_schema(&mut lf)?; + + // shortcut filter evaluation if given expression is just TRUE or FALSE + let (all_true, all_false) = match expr { + SQLExpr::Value(SQLValue::Boolean(b)) => (*b, !*b), + SQLExpr::BinaryOp { left, op, right } => match (&**left, &**right, op) { + (SQLExpr::Value(a), SQLExpr::Value(b), BinaryOperator::Eq) => (a == b, a != b), + (SQLExpr::Value(a), SQLExpr::Value(b), BinaryOperator::NotEq) => { + (a != b, a == b) + }, + _ => (false, false), + }, + _ => (false, false), + }; + if all_true { + return Ok(lf); + } else if all_false { + return Ok(DataFrame::empty_with_schema(schema.as_ref()).lazy()); + } + + // ...otherwise parse and apply the filter as normal + let mut filter_expression = parse_sql_expr(expr, self, Some(schema).as_deref())?; if 
filter_expression.clone().meta().has_multiple_outputs() { filter_expression = all_horizontal([filter_expression])?; } @@ -980,14 +1003,14 @@ impl SQLContext { } => { if let Some(alias) = alias { let table_name = alias.name.value.clone(); - let column_names: Vec> = alias + let column_names: Vec> = alias .columns .iter() .map(|c| { if c.value.is_empty() { None } else { - Some(c.value.as_str()) + Some(PlSmallStr::from_str(c.value.as_str())) } }) .collect(); @@ -1009,8 +1032,8 @@ impl SQLContext { ); } let column_series: Vec = column_values - .iter() - .zip(column_names.iter()) + .into_iter() + .zip(column_names) .map(|(s, name)| { if let Some(name) = name { s.clone().with_name(name) @@ -1076,7 +1099,7 @@ impl SQLContext { return Ok(lf); } let schema = self.get_frame_schema(&mut lf)?; - let columns_iter = schema.iter_names().map(|e| col(e)); + let columns_iter = schema.iter_names().map(|e| col(e.clone())); let order_by = order_by.as_ref().unwrap().exprs.clone(); let mut descending = Vec::with_capacity(order_by.len()); @@ -1154,7 +1177,8 @@ impl SQLContext { .. } = expr.deref() { - projection_overrides.insert(alias.as_ref(), col(name).alias(alias)); + projection_overrides + .insert(alias.as_ref(), col(name.clone()).alias(alias.clone())); } else if !is_agg_or_window && !group_by_keys_schema.contains(alias) { projection_aliases.insert(alias.as_ref()); } @@ -1166,7 +1190,7 @@ impl SQLContext { e = (**expr).clone(); } else if let Expr::Alias(expr, name) = &e { if let Expr::Agg(AggExpr::Implode(expr)) = expr.as_ref() { - e = (**expr).clone().alias(name.as_ref()); + e = (**expr).clone().alias(name.clone()); } } aggregation_projection.push(e); @@ -1199,7 +1223,7 @@ impl SQLContext { { projection_expr.clone() } else { - col(name) + col(name.clone()) } }) .collect::>(); @@ -1317,7 +1341,8 @@ impl SQLContext { RenameSelectItem::Multiple(renames) => renames.iter().collect(), }; for rn in renames { - let (before, after) = (rn.ident.value.clone(), rn.alias.value.clone()); + let (before, after) = (rn.ident.value.as_str(), rn.alias.value.as_str()); + let (before, after) = (PlSmallStr::from_str(before), PlSmallStr::from_str(after)); if before != after { modifiers.rename.insert(before, after); } @@ -1381,8 +1406,8 @@ fn collect_compound_identifiers( right_name: &str, ) -> PolarsResult<(Vec, Vec)> { if left.len() == 2 && right.len() == 2 { - let (tbl_a, col_a) = (&left[0].value, &left[1].value); - let (tbl_b, col_b) = (&right[0].value, &right[1].value); + let (tbl_a, col_a) = (left[0].value.as_str(), left[1].value.as_str()); + let (tbl_b, col_b) = (right[0].value.as_str(), right[1].value.as_str()); // switch left/right operands if the caller has them in reverse if left_name == tbl_b || right_name == tbl_a { @@ -1399,22 +1424,25 @@ fn expand_exprs(expr: Expr, schema: &SchemaRef) -> Vec { match expr { Expr::Wildcard => schema .iter_names() - .map(|name| col(name)) + .map(|name| col(name.clone())) .collect::>(), - Expr::Column(nm) if is_regex_colname(nm.clone()) => { + Expr::Column(nm) if is_regex_colname(nm.as_str()) => { let rx = regex::Regex::new(&nm).unwrap(); schema .iter_names() .filter(|name| rx.is_match(name)) - .map(|name| col(name)) + .map(|name| col(name.clone())) .collect::>() }, - Expr::Columns(names) => names.iter().map(|name| col(name)).collect::>(), + Expr::Columns(names) => names + .iter() + .map(|name| col(name.clone())) + .collect::>(), _ => vec![expr], } } -fn is_regex_colname(nm: ColumnName) -> bool { +fn is_regex_colname(nm: &str) -> bool { nm.starts_with('^') && nm.ends_with('$') } @@ 
-1475,14 +1503,17 @@ fn process_join_constraint( return collect_compound_identifiers(left, right, &tbl_left.name, &tbl_right.name); }, (SQLExpr::Identifier(left), SQLExpr::Identifier(right)) => { - return Ok((vec![col(&left.value)], vec![col(&right.value)])) + return Ok(( + vec![col(left.value.as_str())], + vec![col(right.value.as_str())], + )) }, _ => {}, } }; if let JoinConstraint::Using(idents) = constraint { if !idents.is_empty() { - let using: Vec = idents.iter().map(|id| col(&id.value)).collect(); + let using: Vec = idents.iter().map(|id| col(id.value.as_str())).collect(); return Ok((using.clone(), using.clone())); } }; @@ -1491,7 +1522,7 @@ fn process_join_constraint( let right_names = tbl_right.schema.iter_names().collect::>(); let on = left_names .intersection(&right_names) - .map(|name| col(name)) + .map(|&name| col(name.clone())) .collect::>(); if on.is_empty() { polars_bail!(SQLInterface: "no common columns found for NATURAL JOIN") diff --git a/crates/polars-sql/src/function_registry.rs b/crates/polars-sql/src/function_registry.rs index c85f8307af73..aa693025b072 100644 --- a/crates/polars-sql/src/function_registry.rs +++ b/crates/polars-sql/src/function_registry.rs @@ -1,4 +1,4 @@ -//! This module defines the function registry and user defined functions. +//! This module defines a FunctionRegistry for supported SQL functions and UDFs. use polars_error::{polars_bail, PolarsResult}; use polars_plan::prelude::udf::UserDefinedFunction; diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 0124a5409f7d..87b0656d171d 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -10,6 +10,7 @@ use polars_plan::dsl::{coalesce, concat_str, len, max_horizontal, min_horizontal use polars_plan::plans::{typed_lit, LiteralValue}; use polars_plan::prelude::LiteralValue::Null; use polars_plan::prelude::{col, cols, lit, StrptimeOptions}; +use polars_utils::pl_str::PlSmallStr; use sqlparser::ast::{ DateTimeField, DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, Ident, @@ -983,7 +984,7 @@ impl SQLFunctionVisitor<'_> { parse_extract_date_part( e, &DateTimeField::Custom(Ident { - value: p, + value: p.to_string(), quote_style: None, }), ) @@ -1154,11 +1155,11 @@ impl SQLFunctionVisitor<'_> { Strptime => { let args = extract_args(function)?; match args.len() { - 2 => self.visit_binary(|e, fmt| { + 2 => self.visit_binary(|e, fmt: String| { e.str().strptime( DataType::Datetime(TimeUnit::Microseconds, None), StrptimeOptions { - format: Some(fmt), + format: Some(fmt.into()), ..Default::default() }, lit("latest"), @@ -1274,37 +1275,39 @@ impl SQLFunctionVisitor<'_> { // ---- Columns => { let active_schema = self.active_schema; - self.try_visit_unary(|e: Expr| { - match e { - Expr::Literal(LiteralValue::String(pat)) => { - if "*" == pat { - polars_bail!(SQLSyntax: "COLUMNS('*') is not a valid regex; did you mean COLUMNS(*)?") - }; - let pat = match pat.as_str() { - _ if pat.starts_with('^') && pat.ends_with('$') => pat.to_string(), - _ if pat.starts_with('^') => format!("{}.*$", pat), - _ if pat.ends_with('$') => format!("^.*{}", pat), - _ => format!("^.*{}.*$", pat), - }; - if let Some(active_schema) = &active_schema { - let rx = regex::Regex::new(&pat).unwrap(); - let col_names = active_schema - .iter_names() - .filter(|name| rx.is_match(name)) - .collect::>(); - - Ok(if col_names.len() == 1 { - col(col_names[0]) - } else { - 
cols(col_names) - }) + self.try_visit_unary(|e: Expr| match e { + Expr::Literal(LiteralValue::String(pat)) => { + if pat == "*" { + polars_bail!( + SQLSyntax: "COLUMNS('*') is not a valid regex; \ + did you mean COLUMNS(*)?" + ) + }; + let pat = match pat.as_str() { + _ if pat.starts_with('^') && pat.ends_with('$') => pat.to_string(), + _ if pat.starts_with('^') => format!("{}.*$", pat), + _ if pat.ends_with('$') => format!("^.*{}", pat), + _ => format!("^.*{}.*$", pat), + }; + if let Some(active_schema) = &active_schema { + let rx = regex::Regex::new(&pat).unwrap(); + let col_names = active_schema + .iter_names() + .filter(|name| rx.is_match(name)) + .cloned() + .collect::>(); + + Ok(if col_names.len() == 1 { + col(col_names.into_iter().next().unwrap()) } else { - Ok(col(&pat)) - } - }, - Expr::Wildcard => Ok(col("*")), - _ => polars_bail!(SQLSyntax: "COLUMNS expects a regex; found {:?}", e), - } + cols(col_names) + }) + } else { + Ok(col(pat.as_str())) + } + }, + Expr::Wildcard => Ok(col("*")), + _ => polars_bail!(SQLSyntax: "COLUMNS expects a regex; found {:?}", e), }) }, @@ -1760,7 +1763,7 @@ impl FromSQLExpr for StrptimeOptions { match expr { SQLExpr::Value(v) => match v { SQLValue::SingleQuotedString(s) => Ok(StrptimeOptions { - format: Some(s.clone()), + format: Some(PlSmallStr::from_str(s)), ..StrptimeOptions::default() }), _ => polars_bail!(SQLInterface: "cannot parse literal {:?}", v), diff --git a/crates/polars-sql/src/keywords.rs b/crates/polars-sql/src/keywords.rs index 1442a91cd89f..990bc046aa5b 100644 --- a/crates/polars-sql/src/keywords.rs +++ b/crates/polars-sql/src/keywords.rs @@ -1,10 +1,8 @@ -//! Keywords that are supported by Polars SQL -//! -//! This is useful for syntax highlighting +//! Keywords that are supported by the Polars SQL interface. //! //! This module defines: -//! - all Polars SQL keywords [`all_keywords`] -//! - all of polars SQL functions [`all_functions`] +//! - all recognised Polars SQL keywords [`all_keywords`] +//! - all recognised Polars SQL functions [`all_functions`] use crate::functions::PolarsSQLFunctions; use crate::table_functions::PolarsTableFunctions; diff --git a/crates/polars-sql/src/lib.rs b/crates/polars-sql/src/lib.rs index a811a4cfad9b..528f21eafaf2 100644 --- a/crates/polars-sql/src/lib.rs +++ b/crates/polars-sql/src/lib.rs @@ -7,6 +7,7 @@ mod functions; pub mod keywords; mod sql_expr; mod table_functions; +mod types; pub use context::SQLContext; pub use sql_expr::sql_expr; diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 9374a5fd3229..148a7fe5735e 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -1,3 +1,11 @@ +//! Expressions that are supported by the Polars SQL interface. +//! +//! This is useful for syntax highlighting +//! +//! This module defines: +//! - all Polars SQL keywords [`all_keywords`] +//! 
- all of polars SQL functions [`all_functions`] + use std::fmt::Display; use std::ops::Div; @@ -9,216 +17,39 @@ use polars_plan::prelude::LiteralValue::Null; use polars_time::Duration; use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng}; -use regex::{Regex, RegexBuilder}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -#[cfg(feature = "dtype-decimal")] -use sqlparser::ast::ExactNumberInfo; use sqlparser::ast::{ - ArrayElemTypeDef, BinaryOperator as SQLBinaryOperator, BinaryOperator, CastFormat, CastKind, + BinaryOperator as SQLBinaryOperator, BinaryOperator, CastFormat, CastKind, DataType as SQLDataType, DateTimeField, Expr as SQLExpr, Function as SQLFunction, Ident, - Interval, ObjectName, Query as Subquery, SelectItem, Subscript, TimezoneInfo, TrimWhereField, + Interval, Query as Subquery, SelectItem, Subscript, TimezoneInfo, TrimWhereField, UnaryOperator, Value as SQLValue, }; use sqlparser::dialect::GenericDialect; use sqlparser::parser::{Parser, ParserOptions}; use crate::functions::SQLFunctionVisitor; +use crate::types::{ + bitstring_to_bytes_literal, is_iso_date, is_iso_datetime, is_iso_time, map_sql_dtype_to_polars, +}; use crate::SQLContext; -static DATETIME_LITERAL_RE: std::sync::OnceLock = std::sync::OnceLock::new(); -static DATE_LITERAL_RE: std::sync::OnceLock = std::sync::OnceLock::new(); -static TIME_LITERAL_RE: std::sync::OnceLock = std::sync::OnceLock::new(); - -fn is_iso_datetime(value: &str) -> bool { - let dtm_regex = DATETIME_LITERAL_RE.get_or_init(|| { - RegexBuilder::new( - r"^\d{4}-[01]\d-[0-3]\d[ T](?:[01][0-9]|2[0-3]):[0-5][0-9]:[0-5][0-9](\.\d{1,9})?$", - ) - .build() - .unwrap() - }); - dtm_regex.is_match(value) -} - -fn is_iso_date(value: &str) -> bool { - let dt_regex = DATE_LITERAL_RE.get_or_init(|| { - RegexBuilder::new(r"^\d{4}-[01]\d-[0-3]\d$") - .build() - .unwrap() - }); - dt_regex.is_match(value) -} - -fn is_iso_time(value: &str) -> bool { - let tm_regex = TIME_LITERAL_RE.get_or_init(|| { - RegexBuilder::new(r"^(?:[01][0-9]|2[0-3]):[0-5][0-9]:[0-5][0-9](\.\d{1,9})?$") - .build() - .unwrap() - }); - tm_regex.is_match(value) -} - #[inline] #[cold] #[must_use] +/// Convert a Display-able error to PolarsError::SQLInterface pub fn to_sql_interface_err(err: impl Display) -> PolarsError { PolarsError::SQLInterface(err.to_string().into()) } -fn timeunit_from_precision(prec: &Option) -> PolarsResult { - Ok(match prec { - None => TimeUnit::Microseconds, - Some(n) if (1u64..=3u64).contains(n) => TimeUnit::Milliseconds, - Some(n) if (4u64..=6u64).contains(n) => TimeUnit::Microseconds, - Some(n) if (7u64..=9u64).contains(n) => TimeUnit::Nanoseconds, - Some(n) => { - polars_bail!(SQLSyntax: "invalid temporal type precision (expected 1-9, found {})", n) - }, - }) -} - -pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult { - Ok(match data_type { - // --------------------------------- - // array/list - // --------------------------------- - SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_type)) - | SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_type, _)) => { - DataType::List(Box::new(map_sql_polars_datatype(inner_type)?)) - }, - - // --------------------------------- - // binary - // --------------------------------- - SQLDataType::Bytea - | SQLDataType::Bytes(_) - | SQLDataType::Binary(_) - | SQLDataType::Blob(_) - | SQLDataType::Varbinary(_) => DataType::Binary, - - // --------------------------------- - // boolean - // --------------------------------- - SQLDataType::Boolean | 
SQLDataType::Bool => DataType::Boolean, - - // --------------------------------- - // signed integer - // --------------------------------- - SQLDataType::Int(_) | SQLDataType::Integer(_) => DataType::Int32, - SQLDataType::Int2(_) | SQLDataType::SmallInt(_) => DataType::Int16, - SQLDataType::Int4(_) | SQLDataType::MediumInt(_) => DataType::Int32, - SQLDataType::Int8(_) | SQLDataType::BigInt(_) => DataType::Int64, - SQLDataType::TinyInt(_) => DataType::Int8, - - // --------------------------------- - // unsigned integer: the following do not map to PostgreSQL types/syntax, but - // are enabled for wider compatibility (eg: "CAST(col AS BIGINT UNSIGNED)"). - // --------------------------------- - SQLDataType::UnsignedTinyInt(_) => DataType::UInt8, // see also: "custom" types below - SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) => DataType::UInt32, - SQLDataType::UnsignedInt2(_) | SQLDataType::UnsignedSmallInt(_) => DataType::UInt16, - SQLDataType::UnsignedInt4(_) | SQLDataType::UnsignedMediumInt(_) => DataType::UInt32, - SQLDataType::UnsignedInt8(_) | SQLDataType::UnsignedBigInt(_) | SQLDataType::UInt8 => { - DataType::UInt64 - }, - - // --------------------------------- - // float - // --------------------------------- - SQLDataType::Double | SQLDataType::DoublePrecision | SQLDataType::Float8 => { - DataType::Float64 - }, - SQLDataType::Float(n_bytes) => match n_bytes { - Some(n) if (1u64..=24u64).contains(n) => DataType::Float32, - Some(n) if (25u64..=53u64).contains(n) => DataType::Float64, - Some(n) => { - polars_bail!(SQLSyntax: "unsupported `float` size (expected a value between 1 and 53, found {})", n) - }, - None => DataType::Float64, - }, - SQLDataType::Float4 | SQLDataType::Real => DataType::Float32, - - // --------------------------------- - // decimal - // --------------------------------- - #[cfg(feature = "dtype-decimal")] - SQLDataType::Dec(info) | SQLDataType::Decimal(info) | SQLDataType::Numeric(info) => { - match *info { - ExactNumberInfo::PrecisionAndScale(p, s) => { - DataType::Decimal(Some(p as usize), Some(s as usize)) - }, - ExactNumberInfo::Precision(p) => DataType::Decimal(Some(p as usize), Some(0)), - ExactNumberInfo::None => DataType::Decimal(Some(38), Some(9)), - } - }, - - // --------------------------------- - // temporal - // --------------------------------- - SQLDataType::Date => DataType::Date, - SQLDataType::Interval => DataType::Duration(TimeUnit::Microseconds), - SQLDataType::Time(_, tz) => match tz { - TimezoneInfo::None => DataType::Time, - _ => { - polars_bail!(SQLInterface: "`time` with timezone is not supported; found tz={}", tz) - }, - }, - SQLDataType::Datetime(prec) => DataType::Datetime(timeunit_from_precision(prec)?, None), - SQLDataType::Timestamp(prec, tz) => match tz { - TimezoneInfo::None => DataType::Datetime(timeunit_from_precision(prec)?, None), - _ => { - polars_bail!(SQLInterface: "`timestamp` with timezone is not (yet) supported") - }, - }, - - // --------------------------------- - // string - // --------------------------------- - SQLDataType::Char(_) - | SQLDataType::CharVarying(_) - | SQLDataType::Character(_) - | SQLDataType::CharacterVarying(_) - | SQLDataType::Clob(_) - | SQLDataType::String(_) - | SQLDataType::Text - | SQLDataType::Uuid - | SQLDataType::Varchar(_) => DataType::String, - - // --------------------------------- - // custom - // --------------------------------- - SQLDataType::Custom(ObjectName(idents), _) => match idents.as_slice() { - [Ident { value, .. 
}] => match value.to_lowercase().as_str() { - // these integer types are not supported by the PostgreSQL core distribution, - // but they ARE available via `pguint` (https://github.com/petere/pguint), an - // extension maintained by one of the PostgreSQL core developers. - "uint1" => DataType::UInt8, - "uint2" => DataType::UInt16, - "uint4" | "uint" => DataType::UInt32, - "uint8" => DataType::UInt64, - // `pguint` also provides a 1 byte (8bit) integer type alias - "int1" => DataType::Int8, - _ => { - polars_bail!(SQLInterface: "datatype {:?} is not currently supported", value) - }, - }, - _ => { - polars_bail!(SQLInterface: "datatype {:?} is not currently supported", idents) - }, - }, - _ => { - polars_bail!(SQLInterface: "datatype {:?} is not currently supported", data_type) - }, - }) -} - #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Clone, Copy, PartialEq, Debug, Eq, Hash)] +/// Categorises the type of (allowed) subquery constraint pub enum SubqueryRestriction { - // SingleValue, + /// Subquery must return a single column SingleColumn, // SingleRow, + // SingleValue, // Any } @@ -248,7 +79,7 @@ impl SQLExprVisitor<'_> { }) .collect::>>()?; - Series::from_any_values("", &array_elements, true) + Series::from_any_values(PlSmallStr::EMPTY, &array_elements, true) } fn visit_expr(&mut self, expr: &SQLExpr) -> PolarsResult { @@ -446,7 +277,7 @@ impl SQLExprVisitor<'_> { /// /// e.g. column fn visit_identifier(&self, ident: &Ident) -> PolarsResult { - Ok(col(&ident.value)) + Ok(col(ident.value.as_str())) } /// Visit a compound SQL identifier @@ -542,16 +373,11 @@ impl SQLExprVisitor<'_> { (Some(name.clone()), Some(s), None) }, // identify "CAST(expr AS type) string" and/or "expr::type string" expressions - ( - Expr::Cast { - expr, data_type, .. - }, - Expr::Literal(LiteralValue::String(s)), - ) => { + (Expr::Cast { expr, dtype, .. 
}, Expr::Literal(LiteralValue::String(s))) => { if let Expr::Column(name) = &**expr { - (Some(name.clone()), Some(s), Some(data_type)) + (Some(name.clone()), Some(s), Some(dtype)) } else { - (None, Some(s), Some(data_type)) + (None, Some(s), Some(dtype)) } }, _ => (None, None, None), @@ -879,7 +705,7 @@ impl SQLExprVisitor<'_> { fn visit_cast( &mut self, expr: &SQLExpr, - data_type: &SQLDataType, + dtype: &SQLDataType, format: &Option, cast_kind: &CastKind, ) -> PolarsResult { @@ -891,10 +717,10 @@ impl SQLExprVisitor<'_> { let expr = self.visit_expr(expr)?; #[cfg(feature = "json")] - if data_type == &SQLDataType::JSON { + if dtype == &SQLDataType::JSON { return Ok(expr.str().json_decode(None, None)); } - let polars_type = map_sql_polars_datatype(data_type)?; + let polars_type = map_sql_dtype_to_polars(dtype)?; Ok(match cast_kind { CastKind::Cast | CastKind::DoubleColon => expr.strict_cast(polars_type), CastKind::TryCast | CastKind::SafeCast => expr.cast(polars_type), @@ -989,7 +815,7 @@ impl SQLExprVisitor<'_> { }, } }, - SQLValue::SingleQuotedString(s) => AnyValue::StringOwned(s.into()), + SQLValue::SingleQuotedString(s) => AnyValue::StringOwned(s.as_str().into()), other => polars_bail!(SQLInterface: "value {:?} is not currently supported", other), }) } @@ -1178,7 +1004,7 @@ pub fn sql_expr>(s: S) -> PolarsResult { Ok(match &expr { SelectItem::ExprWithAlias { expr, alias } => { let expr = parse_sql_expr(expr, &mut ctx, None)?; - expr.alias(&alias.value) + expr.alias(alias.value.as_str()) }, SelectItem::UnnamedExpr(expr) => parse_sql_expr(expr, &mut ctx, None)?, _ => polars_bail!(SQLInterface: "unable to parse '{}' as Expr", s.as_ref()), @@ -1324,24 +1150,6 @@ pub(crate) fn adjust_one_indexed_param(idx: Expr, null_if_zero: bool) -> Expr { } } -fn bitstring_to_bytes_literal(b: &String) -> PolarsResult { - let n_bits = b.len(); - if !b.chars().all(|c| c == '0' || c == '1') || n_bits > 64 { - polars_bail!( - SQLSyntax: - "bit string literal should contain only 0s and 1s and have length <= 64; found '{}' with length {}", b, n_bits - ) - } - let s = b.as_str(); - Ok(lit(match n_bits { - 0 => b"".to_vec(), - 1..=8 => u8::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), - 9..=16 => u16::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), - 17..=32 => u32::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), - _ => u64::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), - })) -} - pub(crate) fn resolve_compound_identifier( ctx: &mut SQLContext, idents: &[Ident], @@ -1358,32 +1166,36 @@ pub(crate) fn resolve_compound_identifier( Ok(Arc::new(if let Some(active_schema) = active_schema { active_schema.clone() } else { - Schema::new() + Schema::default() })) }?; let col_dtype: PolarsResult<(Expr, Option<&DataType>)> = if lf.is_none() && schema.is_empty() { - Ok((col(&ident_root.value), None)) + Ok((col(ident_root.value.as_str()), None)) } else { let name = &remaining_idents.next().unwrap().value; if lf.is_some() && name == "*" { return Ok(schema .iter_names() - .map(|name| col(name)) + .map(|name| col(name.clone())) .collect::>()); } else if let Some((_, name, dtype)) = schema.get_full(name) { - let resolved = &ctx.resolve_name(&ident_root.value, name); + let resolved = ctx.resolve_name(&ident_root.value, name); + let resolved = resolved.as_str(); Ok(( if name != resolved { - col(resolved).alias(name) + col(resolved).alias(name.clone()) } else { - col(name) + col(name.clone()) }, Some(dtype), )) } else if lf.is_none() { remaining_idents = idents.iter().skip(1); - Ok((col(&ident_root.value), 
schema.get(&ident_root.value))) + Ok(( + col(ident_root.value.as_str()), + schema.get(&ident_root.value), + )) } else { polars_bail!( SQLInterface: "no column named '{}' found in table '{}'", diff --git a/crates/polars-sql/src/types.rs b/crates/polars-sql/src/types.rs new file mode 100644 index 000000000000..800ead8c233e --- /dev/null +++ b/crates/polars-sql/src/types.rs @@ -0,0 +1,208 @@ +//! This module supports mapping SQL datatypes to Polars datatypes. +//! +//! It also provides utility functions for working with SQL datatypes. +use polars_core::datatypes::{DataType, TimeUnit}; +use polars_core::export::regex::{Regex, RegexBuilder}; +use polars_error::{polars_bail, PolarsResult}; +use polars_plan::dsl::{lit, Expr}; +use sqlparser::ast::{ + ArrayElemTypeDef, DataType as SQLDataType, ExactNumberInfo, Ident, ObjectName, TimezoneInfo, +}; + +static DATETIME_LITERAL_RE: std::sync::OnceLock = std::sync::OnceLock::new(); +static DATE_LITERAL_RE: std::sync::OnceLock = std::sync::OnceLock::new(); +static TIME_LITERAL_RE: std::sync::OnceLock = std::sync::OnceLock::new(); + +pub fn bitstring_to_bytes_literal(b: &String) -> PolarsResult { + let n_bits = b.len(); + if !b.chars().all(|c| c == '0' || c == '1') || n_bits > 64 { + polars_bail!( + SQLSyntax: + "bit string literal should contain only 0s and 1s and have length <= 64; found '{}' with length {}", b, n_bits + ) + } + let s = b.as_str(); + Ok(lit(match n_bits { + 0 => b"".to_vec(), + 1..=8 => u8::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), + 9..=16 => u16::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), + 17..=32 => u32::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), + _ => u64::from_str_radix(s, 2).unwrap().to_be_bytes().to_vec(), + })) +} + +pub fn is_iso_datetime(value: &str) -> bool { + let dtm_regex = DATETIME_LITERAL_RE.get_or_init(|| { + RegexBuilder::new( + r"^\d{4}-[01]\d-[0-3]\d[ T](?:[01][0-9]|2[0-3]):[0-5][0-9]:[0-5][0-9](\.\d{1,9})?$", + ) + .build() + .unwrap() + }); + dtm_regex.is_match(value) +} + +pub fn is_iso_date(value: &str) -> bool { + let dt_regex = DATE_LITERAL_RE.get_or_init(|| { + RegexBuilder::new(r"^\d{4}-[01]\d-[0-3]\d$") + .build() + .unwrap() + }); + dt_regex.is_match(value) +} + +pub fn is_iso_time(value: &str) -> bool { + let tm_regex = TIME_LITERAL_RE.get_or_init(|| { + RegexBuilder::new(r"^(?:[01][0-9]|2[0-3]):[0-5][0-9]:[0-5][0-9](\.\d{1,9})?$") + .build() + .unwrap() + }); + tm_regex.is_match(value) +} + +fn timeunit_from_precision(prec: &Option) -> PolarsResult { + Ok(match prec { + None => TimeUnit::Microseconds, + Some(n) if (1u64..=3u64).contains(n) => TimeUnit::Milliseconds, + Some(n) if (4u64..=6u64).contains(n) => TimeUnit::Microseconds, + Some(n) if (7u64..=9u64).contains(n) => TimeUnit::Nanoseconds, + Some(n) => { + polars_bail!(SQLSyntax: "invalid temporal type precision (expected 1-9, found {})", n) + }, + }) +} + +pub(crate) fn map_sql_dtype_to_polars(dtype: &SQLDataType) -> PolarsResult { + Ok(match dtype { + // --------------------------------- + // array/list + // --------------------------------- + SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_type)) + | SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_type, _)) => { + DataType::List(Box::new(map_sql_dtype_to_polars(inner_type)?)) + }, + + // --------------------------------- + // binary + // --------------------------------- + SQLDataType::Bytea + | SQLDataType::Bytes(_) + | SQLDataType::Binary(_) + | SQLDataType::Blob(_) + | SQLDataType::Varbinary(_) => DataType::Binary, + + // 
--------------------------------- + // boolean + // --------------------------------- + SQLDataType::Boolean | SQLDataType::Bool => DataType::Boolean, + + // --------------------------------- + // signed integer + // --------------------------------- + SQLDataType::Int(_) | SQLDataType::Integer(_) => DataType::Int32, + SQLDataType::Int2(_) | SQLDataType::SmallInt(_) => DataType::Int16, + SQLDataType::Int4(_) | SQLDataType::MediumInt(_) => DataType::Int32, + SQLDataType::Int8(_) | SQLDataType::BigInt(_) => DataType::Int64, + SQLDataType::TinyInt(_) => DataType::Int8, + + // --------------------------------- + // unsigned integer: the following do not map to PostgreSQL types/syntax, but + // are enabled for wider compatibility (eg: "CAST(col AS BIGINT UNSIGNED)"). + // --------------------------------- + SQLDataType::UnsignedTinyInt(_) => DataType::UInt8, // see also: "custom" types below + SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) => DataType::UInt32, + SQLDataType::UnsignedInt2(_) | SQLDataType::UnsignedSmallInt(_) => DataType::UInt16, + SQLDataType::UnsignedInt4(_) | SQLDataType::UnsignedMediumInt(_) => DataType::UInt32, + SQLDataType::UnsignedInt8(_) | SQLDataType::UnsignedBigInt(_) | SQLDataType::UInt8 => { + DataType::UInt64 + }, + + // --------------------------------- + // float + // --------------------------------- + SQLDataType::Double | SQLDataType::DoublePrecision | SQLDataType::Float8 => { + DataType::Float64 + }, + SQLDataType::Float(n_bytes) => match n_bytes { + Some(n) if (1u64..=24u64).contains(n) => DataType::Float32, + Some(n) if (25u64..=53u64).contains(n) => DataType::Float64, + Some(n) => { + polars_bail!(SQLSyntax: "unsupported `float` size (expected a value between 1 and 53, found {})", n) + }, + None => DataType::Float64, + }, + SQLDataType::Float4 | SQLDataType::Real => DataType::Float32, + + // --------------------------------- + // decimal + // --------------------------------- + #[cfg(feature = "dtype-decimal")] + SQLDataType::Dec(info) | SQLDataType::Decimal(info) | SQLDataType::Numeric(info) => { + match *info { + ExactNumberInfo::PrecisionAndScale(p, s) => { + DataType::Decimal(Some(p as usize), Some(s as usize)) + }, + ExactNumberInfo::Precision(p) => DataType::Decimal(Some(p as usize), Some(0)), + ExactNumberInfo::None => DataType::Decimal(Some(38), Some(9)), + } + }, + + // --------------------------------- + // temporal + // --------------------------------- + SQLDataType::Date => DataType::Date, + SQLDataType::Interval => DataType::Duration(TimeUnit::Microseconds), + SQLDataType::Time(_, tz) => match tz { + TimezoneInfo::None => DataType::Time, + _ => { + polars_bail!(SQLInterface: "`time` with timezone is not supported; found tz={}", tz) + }, + }, + SQLDataType::Datetime(prec) => DataType::Datetime(timeunit_from_precision(prec)?, None), + SQLDataType::Timestamp(prec, tz) => match tz { + TimezoneInfo::None => DataType::Datetime(timeunit_from_precision(prec)?, None), + _ => { + polars_bail!(SQLInterface: "`timestamp` with timezone is not (yet) supported") + }, + }, + + // --------------------------------- + // string + // --------------------------------- + SQLDataType::Char(_) + | SQLDataType::CharVarying(_) + | SQLDataType::Character(_) + | SQLDataType::CharacterVarying(_) + | SQLDataType::Clob(_) + | SQLDataType::String(_) + | SQLDataType::Text + | SQLDataType::Uuid + | SQLDataType::Varchar(_) => DataType::String, + + // --------------------------------- + // custom + // --------------------------------- + 
SQLDataType::Custom(ObjectName(idents), _) => match idents.as_slice() { + [Ident { value, .. }] => match value.to_lowercase().as_str() { + // these integer types are not supported by the PostgreSQL core distribution, + // but they ARE available via `pguint` (https://github.com/petere/pguint), an + // extension maintained by one of the PostgreSQL core developers. + "uint1" => DataType::UInt8, + "uint2" => DataType::UInt16, + "uint4" | "uint" => DataType::UInt32, + "uint8" => DataType::UInt64, + // `pguint` also provides a 1 byte (8bit) integer type alias + "int1" => DataType::Int8, + _ => { + polars_bail!(SQLInterface: "datatype {:?} is not currently supported", value) + }, + }, + _ => { + polars_bail!(SQLInterface: "datatype {:?} is not currently supported", idents) + }, + }, + _ => { + polars_bail!(SQLInterface: "datatype {:?} is not currently supported", dtype) + }, + }) +} diff --git a/crates/polars-sql/tests/issues.rs b/crates/polars-sql/tests/issues.rs index 31c0a89e84ff..10ee22db49d3 100644 --- a/crates/polars-sql/tests/issues.rs +++ b/crates/polars-sql/tests/issues.rs @@ -113,7 +113,7 @@ fn iss_8395() -> PolarsResult<()> { // assert that the df only contains [vegetables, seafood] let s = df.column("category")?.unique()?.sort(Default::default())?; - let expected = Series::new("category", &["seafood", "vegetables"]); + let expected = Series::new("category".into(), &["seafood", "vegetables"]); assert!(s.equals(&expected)); Ok(()) } diff --git a/crates/polars-sql/tests/simple_exprs.rs b/crates/polars-sql/tests/simple_exprs.rs index c37e12ed3040..b84c6e681cd2 100644 --- a/crates/polars-sql/tests/simple_exprs.rs +++ b/crates/polars-sql/tests/simple_exprs.rs @@ -4,8 +4,11 @@ use polars_sql::*; use polars_time::Duration; fn create_sample_df() -> DataFrame { - let a = Series::new("a", (1..10000i64).map(|i| i / 100).collect::>()); - let b = Series::new("b", 1..10000i64); + let a = Series::new( + "a".into(), + (1..10000i64).map(|i| i / 100).collect::>(), + ); + let b = Series::new("b".into(), 1..10000i64); DataFrame::new(vec![a, b]).unwrap() } diff --git a/crates/polars-sql/tests/statements.rs b/crates/polars-sql/tests/statements.rs index e5be8e598b60..2657ec443077 100644 --- a/crates/polars-sql/tests/statements.rs +++ b/crates/polars-sql/tests/statements.rs @@ -3,8 +3,8 @@ use polars_lazy::prelude::*; use polars_sql::*; fn create_ctx() -> SQLContext { - let a = Series::new("a", (1..10i64).map(|i| i / 100).collect::>()); - let b = Series::new("b", 1..10i64); + let a = Series::new("a".into(), (1..10i64).map(|i| i / 100).collect::>()); + let b = Series::new("b".into(), 1..10i64); let df = DataFrame::new(vec![a, b]).unwrap().lazy(); let mut ctx = SQLContext::new(); ctx.register("df", df); diff --git a/crates/polars-sql/tests/udf.rs b/crates/polars-sql/tests/udf.rs index 66eb0353b07d..3ccd1c4d6395 100644 --- a/crates/polars-sql/tests/udf.rs +++ b/crates/polars-sql/tests/udf.rs @@ -33,10 +33,10 @@ impl FunctionRegistry for MyFunctionRegistry { #[test] fn test_udfs() -> PolarsResult<()> { let my_custom_sum = UserDefinedFunction::new( - "my_custom_sum", + "my_custom_sum".into(), vec![ - Field::new("a", DataType::Int32), - Field::new("b", DataType::Int32), + Field::new("a".into(), DataType::Int32), + Field::new("b".into(), DataType::Int32), ], GetOutput::same_type(), move |s: &mut [Series]| { @@ -68,10 +68,10 @@ fn test_udfs() -> PolarsResult<()> { // create a new UDF to be registered on the context let my_custom_divide = UserDefinedFunction::new( - "my_custom_divide", + "my_custom_divide".into(), 
vec![ - Field::new("a", DataType::Int32), - Field::new("b", DataType::Int32), + Field::new("a".into(), DataType::Int32), + Field::new("b".into(), DataType::Int32), ], GetOutput::same_type(), move |s: &mut [Series]| { diff --git a/crates/polars-stream/Cargo.toml b/crates/polars-stream/Cargo.toml index a8741189f7dd..e2a7d0c45649 100644 --- a/crates/polars-stream/Cargo.toml +++ b/crates/polars-stream/Cargo.toml @@ -12,9 +12,11 @@ description = "Private crate for the streaming execution engine for the Polars D atomic-waker = { workspace = true } crossbeam-deque = { workspace = true } crossbeam-utils = { workspace = true } +futures = { workspace = true } +memmap = { workspace = true } parking_lot = { workspace = true } pin-project-lite = { workspace = true } -polars-io = { workspace = true, features = ["async"] } +polars-io = { workspace = true, features = ["async", "cloud", "aws"] } polars-utils = { workspace = true } rand = { workspace = true } rayon = { workspace = true } @@ -25,8 +27,9 @@ tokio = { workspace = true } polars-core = { workspace = true } polars-error = { workspace = true } polars-expr = { workspace = true } -polars-mem-engine = { workspace = true } -polars-plan = { workspace = true } +polars-mem-engine = { workspace = true, features = ["parquet"] } +polars-parquet = { workspace = true } +polars-plan = { workspace = true, features = ["parquet"] } [build-dependencies] version_check = { workspace = true } diff --git a/crates/polars-stream/src/async_executor/mod.rs b/crates/polars-stream/src/async_executor/mod.rs index eb549cc7c1fa..dec560845b09 100644 --- a/crates/polars-stream/src/async_executor/mod.rs +++ b/crates/polars-stream/src/async_executor/mod.rs @@ -15,7 +15,7 @@ use parking_lot::Mutex; use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; use slotmap::SlotMap; -pub use task::JoinHandle; +pub use task::{AbortOnDropHandle, JoinHandle}; use task::{CancelHandle, Runnable}; static NUM_EXECUTOR_THREADS: AtomicUsize = AtomicUsize::new(0); @@ -42,18 +42,23 @@ pub enum TaskPriority { } /// Metadata associated with a task to help schedule it and clean it up. 
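+// The completed-task bookkeeping lives in an optional `ScopedTaskMetadata` so that
+// tasks spawned outside a `TaskScope` (see the free-standing `spawn` added further
+// below, which passes `scoped: None`) carry no scope state, while scoped tasks still
+// report their key to the scope's completed-task list when they are dropped.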
+struct ScopedTaskMetadata { + task_key: TaskKey, + completed_tasks: Weak>>, +} + struct TaskMetadata { priority: TaskPriority, freshly_spawned: AtomicBool, - - task_key: TaskKey, - completed_tasks: Weak>>, + scoped: Option, } impl Drop for TaskMetadata { fn drop(&mut self) { - if let Some(completed_tasks) = self.completed_tasks.upgrade() { - completed_tasks.lock().push(self.task_key); + if let Some(scoped) = &self.scoped { + if let Some(completed_tasks) = scoped.completed_tasks.upgrade() { + completed_tasks.lock().push(scoped.task_key); + } } } } @@ -296,10 +301,12 @@ impl<'scope, 'env> TaskScope<'scope, 'env> { fut, on_wake, TaskMetadata { - task_key, priority, freshly_spawned: AtomicBool::new(true), - completed_tasks: Arc::downgrade(&self.completed_tasks), + scoped: Some(ScopedTaskMetadata { + task_key, + completed_tasks: Arc::downgrade(&self.completed_tasks), + }), }, ) }; @@ -338,6 +345,25 @@ where } } +pub fn spawn(priority: TaskPriority, fut: F) -> JoinHandle +where + ::Output: Send + 'static, +{ + let executor = Executor::global(); + let on_wake = move |task| executor.schedule_task(task); + let (runnable, join_handle) = task::spawn( + fut, + on_wake, + TaskMetadata { + priority, + freshly_spawned: AtomicBool::new(true), + scoped: None, + }, + ); + runnable.schedule(); + join_handle +} + fn random_permutation(len: u32, rng: &mut R) -> impl Iterator { let modulus = len.next_power_of_two(); let halfwidth = modulus.trailing_zeros() / 2; diff --git a/crates/polars-stream/src/async_executor/task.rs b/crates/polars-stream/src/async_executor/task.rs index b87b2a7b4be3..9991377eb718 100644 --- a/crates/polars-stream/src/async_executor/task.rs +++ b/crates/polars-stream/src/async_executor/task.rs @@ -278,6 +278,10 @@ impl Runnable { pub struct JoinHandle(Option>>); pub struct CancelHandle(Weak); +pub struct AbortOnDropHandle { + join_handle: JoinHandle, + cancel_handle: CancelHandle, +} impl JoinHandle { pub fn cancel_handle(&self) -> CancelHandle { @@ -305,15 +309,38 @@ impl Future for JoinHandle { } impl CancelHandle { - pub fn cancel(self) { + pub fn cancel(&self) { if let Some(t) = self.0.upgrade() { t.cancel(); } } } -#[allow(unused)] -pub fn spawn(future: F, schedule: S, metadata: M) -> JoinHandle +impl AbortOnDropHandle { + pub fn new(join_handle: JoinHandle) -> Self { + let cancel_handle = join_handle.cancel_handle(); + Self { + join_handle, + cancel_handle, + } + } +} + +impl Future for AbortOnDropHandle { + type Output = T; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.join_handle).poll(cx) + } +} + +impl Drop for AbortOnDropHandle { + fn drop(&mut self) { + self.cancel_handle.cancel(); + } +} + +pub fn spawn(future: F, schedule: S, metadata: M) -> (Runnable, JoinHandle) where F: Future + Send + 'static, F::Output: Send + 'static, @@ -321,7 +348,7 @@ where M: Send + Sync + 'static, { let task = unsafe { Task::spawn(future, schedule, metadata) }; - JoinHandle(Some(task)) + (task.clone().into_runnable(), task.into_join_handle()) } /// Takes a future and turns it into a runnable task with associated metadata. 
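For reference, the `AbortOnDropHandle` added above is the usual abort-on-drop guard: cancelling the wrapped task when the handle is dropped, so a task cannot outlive its owner (the parquet source node below relies on this to tie its producer tasks to the node's lifetime). A minimal standalone sketch of the same pattern, written against tokio (with its `full` feature set) rather than this crate's executor, so the `AbortOnDrop` name here is purely illustrative:

use std::time::Duration;
use tokio::task::JoinHandle;

// Guard that aborts the wrapped task when it goes out of scope.
struct AbortOnDrop<T>(JoinHandle<T>);

impl<T> Drop for AbortOnDrop<T> {
    fn drop(&mut self) {
        self.0.abort();
    }
}

#[tokio::main]
async fn main() {
    let guard = AbortOnDrop(tokio::spawn(async {
        tokio::time::sleep(Duration::from_secs(60)).await;
    }));
    // Dropping the guard aborts the sleeping task instead of leaking it.
    drop(guard);
}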
diff --git a/crates/polars-stream/src/async_primitives/distributor_channel.rs b/crates/polars-stream/src/async_primitives/distributor_channel.rs index 5bdeb7e56866..21af7b53d7d1 100644 --- a/crates/polars-stream/src/async_primitives/distributor_channel.rs +++ b/crates/polars-stream/src/async_primitives/distributor_channel.rs @@ -198,6 +198,8 @@ impl Sender { } impl Receiver { + /// Note: This intentionally takes `&mut` to ensure it is only accessed in a single-threaded + /// manner. pub async fn recv(&mut self) -> Result { loop { // Fast-path. diff --git a/crates/polars-stream/src/execute.rs b/crates/polars-stream/src/execute.rs index 5f3bdebf7d36..d17bc89bd6ad 100644 --- a/crates/polars-stream/src/execute.rs +++ b/crates/polars-stream/src/execute.rs @@ -205,10 +205,11 @@ fn run_subgraph( for input in &node.inputs { let sender = graph.pipes[*input].sender; if let Some(count) = num_send_ports_not_yet_ready.get_mut(sender) { - assert!(*count > 0); - *count -= 1; - if *count == 0 { - ready.push(sender); + if *count > 0 { + *count -= 1; + if *count == 0 { + ready.push(sender); + } } } } @@ -247,7 +248,7 @@ pub fn execute_graph( if polars_core::config::verbose() { eprintln!("polars-stream: updating graph state"); } - graph.update_all_states(); + graph.update_all_states()?; let (nodes, pipes) = find_runnable_subgraph(graph); if polars_core::config::verbose() { for node in &nodes { diff --git a/crates/polars-stream/src/expression.rs b/crates/polars-stream/src/expression.rs index a6e41728d111..3c1b9445997c 100644 --- a/crates/polars-stream/src/expression.rs +++ b/crates/polars-stream/src/expression.rs @@ -6,7 +6,7 @@ use polars_error::PolarsResult; use polars_expr::prelude::{ExecutionState, PhysicalExpr}; #[derive(Clone)] -pub(crate) struct StreamExpr { +pub struct StreamExpr { inner: Arc, // Whether the expression can be re-entering the engine (e.g. a function use the lazy api // within that function) @@ -14,18 +14,14 @@ pub(crate) struct StreamExpr { } impl StreamExpr { - pub(crate) fn new(phys_expr: Arc, reentrant: bool) -> Self { + pub fn new(phys_expr: Arc, reentrant: bool) -> Self { Self { inner: phys_expr, reentrant, } } - pub(crate) async fn evaluate( - &self, - df: &DataFrame, - state: &ExecutionState, - ) -> PolarsResult { + pub async fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { if self.reentrant { let state = state.clone(); let phys_expr = self.inner.clone(); diff --git a/crates/polars-stream/src/graph.rs b/crates/polars-stream/src/graph.rs index 055d8df4a5ae..572c1f1c306d 100644 --- a/crates/polars-stream/src/graph.rs +++ b/crates/polars-stream/src/graph.rs @@ -1,3 +1,4 @@ +use polars_error::PolarsResult; use slotmap::{SecondaryMap, SlotMap}; use crate::nodes::ComputeNode; @@ -64,11 +65,13 @@ impl Graph { } /// Updates all the nodes' states until a fixed point is reached. - pub fn update_all_states(&mut self) { + pub fn update_all_states(&mut self) -> PolarsResult<()> { let mut to_update: Vec<_> = self.nodes.keys().collect(); let mut scheduled_for_update: SecondaryMap = self.nodes.keys().map(|k| (k, ())).collect(); + let verbose = std::env::var("POLARS_VERBOSE_STATE_UPDATE").as_deref() == Ok("1"); + let mut recv_state = Vec::new(); let mut send_state = Vec::new(); while let Some(node_key) = to_update.pop() { @@ -82,15 +85,25 @@ impl Graph { send_state.extend(node.outputs.iter().map(|o| self.pipes[*o].recv_state)); // Compute the new state of this node given its environment. 
- // eprintln!("updating {}, before: {recv_state:?} {send_state:?}", node.compute.name()); - node.compute.update_state(&mut recv_state, &mut send_state); - // eprintln!("updating {}, after: {recv_state:?} {send_state:?}", node.compute.name()); + if verbose { + eprintln!( + "updating {}, before: {recv_state:?} {send_state:?}", + node.compute.name() + ); + } + node.compute + .update_state(&mut recv_state, &mut send_state)?; + if verbose { + eprintln!( + "updating {}, after: {recv_state:?} {send_state:?}", + node.compute.name() + ); + } // Propagate information. for (input, state) in node.inputs.iter().zip(recv_state.iter()) { let pipe = &mut self.pipes[*input]; if pipe.recv_state != *state { - // eprintln!("transitioning input pipe from {:?} to {state:?}", pipe.recv_state); assert!(pipe.recv_state != PortState::Done, "implementation error: state transition from Done to Blocked/Ready attempted"); pipe.recv_state = *state; if scheduled_for_update.insert(pipe.sender, ()).is_none() { @@ -102,7 +115,6 @@ impl Graph { for (output, state) in node.outputs.iter().zip(send_state.iter()) { let pipe = &mut self.pipes[*output]; if pipe.send_state != *state { - // eprintln!("transitioning output pipe from {:?} to {state:?}", pipe.send_state); assert!(pipe.send_state != PortState::Done, "implementation error: state transition from Done to Blocked/Ready attempted"); pipe.send_state = *state; if scheduled_for_update.insert(pipe.receiver, ()).is_none() { @@ -111,6 +123,7 @@ impl Graph { } } } + Ok(()) } } diff --git a/crates/polars-stream/src/morsel.rs b/crates/polars-stream/src/morsel.rs index 9ba0b4b0288e..766129da0a20 100644 --- a/crates/polars-stream/src/morsel.rs +++ b/crates/polars-stream/src/morsel.rs @@ -67,7 +67,7 @@ impl SourceToken { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Morsel { /// The data contained in this morsel. df: DataFrame, @@ -152,4 +152,8 @@ impl Morsel { pub fn source_token(&self) -> &SourceToken { &self.source_token } + + pub fn replace_source_token(&mut self, new_token: SourceToken) -> SourceToken { + core::mem::replace(&mut self.source_token, new_token) + } } diff --git a/crates/polars-stream/src/nodes/filter.rs b/crates/polars-stream/src/nodes/filter.rs index 8a19b1a27986..9f0b0301ef91 100644 --- a/crates/polars-stream/src/nodes/filter.rs +++ b/crates/polars-stream/src/nodes/filter.rs @@ -18,9 +18,10 @@ impl ComputeNode for FilterNode { "filter" } - fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) { + fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> { assert!(recv.len() == 1 && send.len() == 1); recv.swap_with_slice(send); + Ok(()) } fn spawn<'env, 's>( diff --git a/crates/polars-stream/src/nodes/in_memory_map.rs b/crates/polars-stream/src/nodes/in_memory_map.rs index 09769172c430..3a8bff496a18 100644 --- a/crates/polars-stream/src/nodes/in_memory_map.rs +++ b/crates/polars-stream/src/nodes/in_memory_map.rs @@ -39,7 +39,7 @@ impl ComputeNode for InMemoryMapNode { } } - fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) { + fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> { assert!(recv.len() == 1 && send.len() == 1); // If the output doesn't want any more data, transition to being done. 
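Usage note: the ad-hoc `eprintln!` debugging removed from `graph.rs` above is now behind the `POLARS_VERBOSE_STATE_UPDATE` environment variable. A minimal sketch of enabling it from code (exporting the variable in the shell before running works equally well; the query itself is only a placeholder):

fn main() {
    // Enable the opt-in state-transition logging added in graph.rs above.
    std::env::set_var("POLARS_VERBOSE_STATE_UPDATE", "1");

    // ... build and collect a streaming query here; each node's port states
    // are printed to stderr before and after every `update_state` call ...
}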
@@ -55,9 +55,8 @@ impl ComputeNode for InMemoryMapNode { } = self { if recv[0] == PortState::Done { - let df = sink_node.get_output().unwrap(); - let mut source_node = - InMemorySourceNode::new(Arc::new(map.call_udf(df.unwrap()).unwrap())); + let df = sink_node.get_output()?; + let mut source_node = InMemorySourceNode::new(Arc::new(map.call_udf(df.unwrap())?)); source_node.initialize(*num_pipelines); *self = Self::Source(source_node); } @@ -65,18 +64,19 @@ impl ComputeNode for InMemoryMapNode { match self { Self::Sink { sink_node, .. } => { - sink_node.update_state(recv, &mut []); + sink_node.update_state(recv, &mut [])?; send[0] = PortState::Blocked; }, Self::Source(source_node) => { recv[0] = PortState::Done; - source_node.update_state(&mut [], send); + source_node.update_state(&mut [], send)?; }, Self::Done => { recv[0] = PortState::Done; send[0] = PortState::Done; }, } + Ok(()) } fn is_memory_intensive_pipeline_blocker(&self) -> bool { diff --git a/crates/polars-stream/src/nodes/in_memory_sink.rs b/crates/polars-stream/src/nodes/in_memory_sink.rs index 0a4750d7b8b9..afd6ccfd95cc 100644 --- a/crates/polars-stream/src/nodes/in_memory_sink.rs +++ b/crates/polars-stream/src/nodes/in_memory_sink.rs @@ -26,7 +26,7 @@ impl ComputeNode for InMemorySinkNode { "in_memory_sink" } - fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) { + fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> { assert!(send.is_empty()); assert!(recv.len() == 1); @@ -35,6 +35,7 @@ impl ComputeNode for InMemorySinkNode { if recv[0] != PortState::Done { recv[0] = PortState::Ready; } + Ok(()) } fn is_memory_intensive_pipeline_blocker(&self) -> bool { diff --git a/crates/polars-stream/src/nodes/in_memory_source.rs b/crates/polars-stream/src/nodes/in_memory_source.rs index 826f9e5e5c83..45630eb7aab0 100644 --- a/crates/polars-stream/src/nodes/in_memory_source.rs +++ b/crates/polars-stream/src/nodes/in_memory_source.rs @@ -34,7 +34,7 @@ impl ComputeNode for InMemorySourceNode { self.seq = AtomicU64::new(0); } - fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) { + fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> { assert!(recv.is_empty()); assert!(send.len() == 1); @@ -52,6 +52,7 @@ impl ComputeNode for InMemorySourceNode { } else { send[0] = PortState::Ready; } + Ok(()) } fn spawn<'env, 's>( diff --git a/crates/polars-stream/src/nodes/map.rs b/crates/polars-stream/src/nodes/map.rs index 44587193f23d..007dfa921672 100644 --- a/crates/polars-stream/src/nodes/map.rs +++ b/crates/polars-stream/src/nodes/map.rs @@ -20,9 +20,10 @@ impl ComputeNode for MapNode { "map" } - fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) { + fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> { assert!(recv.len() == 1 && send.len() == 1); recv.swap_with_slice(send); + Ok(()) } fn spawn<'env, 's>( diff --git a/crates/polars-stream/src/nodes/mod.rs b/crates/polars-stream/src/nodes/mod.rs index fecc1b8e5abe..4c71380e0ad4 100644 --- a/crates/polars-stream/src/nodes/mod.rs +++ b/crates/polars-stream/src/nodes/mod.rs @@ -3,7 +3,9 @@ pub mod in_memory_map; pub mod in_memory_sink; pub mod in_memory_source; pub mod map; +pub mod multiplexer; pub mod ordered_union; +pub mod parquet_source; pub mod reduce; pub mod select; pub mod simple_projection; @@ -44,7 +46,7 @@ pub trait ComputeNode: Send { /// Similarly, for each output pipe `send` will contain the 
respective /// state of the input port that pipe is connected to when called, and you /// must update it to contain the desired state of your output port. - fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]); + fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()>; /// If this node (in its current state) is a pipeline blocker, and whether /// this is memory intensive or not. diff --git a/crates/polars-stream/src/nodes/multiplexer.rs b/crates/polars-stream/src/nodes/multiplexer.rs new file mode 100644 index 000000000000..65f2e752d28d --- /dev/null +++ b/crates/polars-stream/src/nodes/multiplexer.rs @@ -0,0 +1,206 @@ +use std::collections::VecDeque; + +use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; + +use super::compute_node_prelude::*; +use crate::morsel::SourceToken; + +// TODO: replace this with an out-of-core buffering solution. +enum BufferedStream { + Open(VecDeque), + Closed, +} + +impl BufferedStream { + fn new() -> Self { + Self::Open(VecDeque::new()) + } +} + +pub struct MultiplexerNode { + buffers: Vec, +} + +impl MultiplexerNode { + pub fn new() -> Self { + Self { + buffers: Vec::default(), + } + } +} + +impl ComputeNode for MultiplexerNode { + fn name(&self) -> &str { + "multiplexer" + } + + fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> { + assert!(recv.len() == 1 && !send.is_empty()); + + // Initialize buffered streams, and mark those for which the receiver + // is no longer interested as closed. + self.buffers.resize_with(send.len(), BufferedStream::new); + for (s, b) in send.iter().zip(&mut self.buffers) { + if *s == PortState::Done { + *b = BufferedStream::Closed; + } + } + + // Check if either the input is done, or all outputs are done. + let input_done = recv[0] == PortState::Done + && self.buffers.iter().all(|b| match b { + BufferedStream::Open(v) => v.is_empty(), + BufferedStream::Closed => true, + }); + let output_done = send.iter().all(|p| *p == PortState::Done); + + // If either side is done, everything is done. + if input_done || output_done { + recv[0] = PortState::Done; + for s in send { + *s = PortState::Done; + } + return Ok(()); + } + + let all_blocked = send.iter().all(|p| *p == PortState::Blocked); + + // Pass along the input state to the output. + for (i, s) in send.iter_mut().enumerate() { + let buffer_empty = match &self.buffers[i] { + BufferedStream::Open(v) => v.is_empty(), + BufferedStream::Closed => true, + }; + *s = if buffer_empty && recv[0] == PortState::Done { + PortState::Done + } else if !buffer_empty || recv[0] == PortState::Ready { + PortState::Ready + } else { + PortState::Blocked + }; + } + + // We say we are ready to receive unless all outputs are blocked. 
+ recv[0] = if all_blocked { + PortState::Blocked + } else { + PortState::Ready + }; + Ok(()) + } + + fn spawn<'env, 's>( + &'env mut self, + scope: &'s TaskScope<'s, 'env>, + recv: &mut [Option>], + send: &mut [Option>], + _state: &'s ExecutionState, + join_handles: &mut Vec>>, + ) { + assert!(recv.len() == 1 && !send.is_empty()); + assert!(self.buffers.len() == send.len()); + + enum Listener<'a> { + Active(UnboundedSender), + Buffering(&'a mut VecDeque), + Inactive, + } + + let buffered_source_token = SourceToken::new(); + + let (mut buf_senders, buf_receivers): (Vec<_>, Vec<_>) = self + .buffers + .iter_mut() + .enumerate() + .map(|(port_idx, buffer)| { + if let BufferedStream::Open(buf) = buffer { + if send[port_idx].is_some() { + // TODO: replace with a bounded channel and store data + // out-of-core beyond a certain size. + let (rx, tx) = unbounded_channel(); + (Listener::Active(rx), Some((buf, tx))) + } else { + (Listener::Buffering(buf), None) + } + } else { + (Listener::Inactive, None) + } + }) + .unzip(); + + // TODO: parallel multiplexing. + if let Some(mut receiver) = recv[0].take().map(|r| r.serial()) { + let buffered_source_token = buffered_source_token.clone(); + join_handles.push(scope.spawn_task(TaskPriority::High, async move { + loop { + let Ok(morsel) = receiver.recv().await else { + break; + }; + + let mut anyone_interested = false; + let mut active_listener_interested = false; + for buf_sender in &mut buf_senders { + match buf_sender { + Listener::Active(s) => match s.send(morsel.clone()) { + Ok(_) => { + anyone_interested = true; + active_listener_interested = true; + }, + Err(_) => *buf_sender = Listener::Inactive, + }, + Listener::Buffering(b) => { + // Make sure to count buffered morsels as + // consumed to not block the source. + let mut m = morsel.clone(); + m.take_consume_token(); + b.push_front(m); + anyone_interested = true; + }, + Listener::Inactive => {}, + } + } + + if !anyone_interested { + break; + } + + // If only buffering inputs are left, or we got a stop + // request from an input reading from old buffered data, + // request a stop from the source. + if !active_listener_interested || buffered_source_token.stop_requested() { + morsel.source_token().stop(); + } + } + + Ok(()) + })); + } + + for (send_port, opt_buf_recv) in send.iter_mut().zip(buf_receivers) { + if let Some((buf, mut rx)) = opt_buf_recv { + let mut sender = send_port.take().unwrap().serial(); + + let buffered_source_token = buffered_source_token.clone(); + join_handles.push(scope.spawn_task(TaskPriority::High, async move { + // First we try to flush all the old buffered data. + while let Some(mut morsel) = buf.pop_back() { + morsel.replace_source_token(buffered_source_token.clone()); + if sender.send(morsel).await.is_err() + || buffered_source_token.stop_requested() + { + break; + } + } + + // Then send along data from the multiplexer. 
+ while let Some(morsel) = rx.recv().await { + if sender.send(morsel).await.is_err() { + break; + } + } + Ok(()) + })); + } + } + } +} diff --git a/crates/polars-stream/src/nodes/ordered_union.rs b/crates/polars-stream/src/nodes/ordered_union.rs index f38c306505b4..3c72d9cc6e15 100644 --- a/crates/polars-stream/src/nodes/ordered_union.rs +++ b/crates/polars-stream/src/nodes/ordered_union.rs @@ -23,7 +23,7 @@ impl ComputeNode for OrderedUnionNode { "ordered_union" } - fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) { + fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> { assert!(self.cur_input_idx <= recv.len() && send.len() == 1); // Skip inputs that are done. @@ -46,6 +46,7 @@ impl ComputeNode for OrderedUnionNode { // Set the morsel offset one higher than any sent so far. self.morsel_offset = self.max_morsel_seq_sent.successor(); + Ok(()) } fn spawn<'env, 's>( diff --git a/crates/polars-stream/src/nodes/parquet_source.rs b/crates/polars-stream/src/nodes/parquet_source.rs new file mode 100644 index 000000000000..bf5d4262fed6 --- /dev/null +++ b/crates/polars-stream/src/nodes/parquet_source.rs @@ -0,0 +1,1862 @@ +use std::future::Future; +use std::path::PathBuf; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; + +use futures::stream::FuturesUnordered; +use futures::StreamExt; +use polars_core::config; +use polars_core::frame::DataFrame; +use polars_core::prelude::{ + ArrowSchema, ChunkFull, DataType, IdxCa, InitHashMaps, PlHashMap, StringChunked, +}; +use polars_core::series::{IntoSeries, IsSorted, Series}; +use polars_core::utils::operation_exceeded_idxsize_msg; +use polars_error::{polars_bail, polars_err, PolarsResult}; +use polars_expr::prelude::PhysicalExpr; +use polars_io::cloud::CloudOptions; +use polars_io::predicates::PhysicalIoExpr; +use polars_io::prelude::_internal::read_this_row_group; +use polars_io::prelude::{FileMetaData, ParquetOptions}; +use polars_io::utils::byte_source::{ + ByteSource, DynByteSource, DynByteSourceBuilder, MemSliceByteSource, +}; +use polars_io::utils::slice::SplitSlicePosition; +use polars_io::{is_cloud_url, RowIndex}; +use polars_parquet::read::RowGroupMetaData; +use polars_plan::plans::hive::HivePartitions; +use polars_plan::plans::FileInfo; +use polars_plan::prelude::FileScanOptions; +use polars_utils::mmap::MemSlice; +use polars_utils::pl_str::PlSmallStr; +use polars_utils::slice::GetSaferUnchecked; +use polars_utils::IdxSize; + +use super::{MorselSeq, TaskPriority}; +use crate::async_executor::{self}; +use crate::async_primitives::connector::connector; +use crate::async_primitives::wait_group::{WaitGroup, WaitToken}; +use crate::morsel::get_ideal_morsel_size; +use crate::utils::task_handles_ext; + +type AsyncTaskData = Option<( + Vec>, + async_executor::AbortOnDropHandle>, +)>; + +#[allow(clippy::type_complexity)] +pub struct ParquetSourceNode { + paths: Arc<[PathBuf]>, + file_info: FileInfo, + hive_parts: Option>>, + predicate: Option>, + options: ParquetOptions, + cloud_options: Option, + file_options: FileScanOptions, + // Run-time vars + config: Config, + verbose: bool, + physical_predicate: Option>, + projected_arrow_fields: Arc<[polars_core::prelude::ArrowField]>, + byte_source_builder: DynByteSourceBuilder, + memory_prefetch_func: fn(&[u8]) -> (), + // This permit blocks execution until the first morsel is requested. + morsel_stream_starter: Option>, + // This is behind a Mutex so that we can call `shutdown()` asynchronously. 
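+    // It holds the per-pipeline raw morsel receivers together with the abort-on-drop
+    // handle of the morsel-stream producer task; `shutdown_impl` takes it out exactly once.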
+ async_task_data: Arc>, + row_group_decoder: Option>, + is_finished: Arc, +} + +#[allow(clippy::too_many_arguments)] +impl ParquetSourceNode { + pub fn new( + paths: Arc<[PathBuf]>, + file_info: FileInfo, + hive_parts: Option>>, + predicate: Option>, + options: ParquetOptions, + cloud_options: Option, + file_options: FileScanOptions, + ) -> Self { + let verbose = config::verbose(); + + let byte_source_builder = + if is_cloud_url(paths[0].to_str().unwrap()) || config::force_async() { + DynByteSourceBuilder::ObjectStore + } else { + DynByteSourceBuilder::Mmap + }; + let memory_prefetch_func = get_memory_prefetch_func(verbose); + + Self { + paths, + file_info, + hive_parts, + predicate, + options, + cloud_options, + file_options, + + config: Config { + // Initialized later + num_pipelines: 0, + metadata_prefetch_size: 0, + metadata_decode_ahead_size: 0, + row_group_prefetch_size: 0, + }, + verbose, + physical_predicate: None, + projected_arrow_fields: Arc::new([]), + byte_source_builder, + memory_prefetch_func, + + morsel_stream_starter: None, + async_task_data: Arc::new(tokio::sync::Mutex::new(None)), + row_group_decoder: None, + is_finished: Arc::new(AtomicBool::new(false)), + } + } +} + +mod compute_node_impl { + + use std::sync::Arc; + + use polars_expr::prelude::phys_expr_to_io_expr; + + use super::super::compute_node_prelude::*; + use super::{Config, ParquetSourceNode}; + use crate::morsel::SourceToken; + + impl ComputeNode for ParquetSourceNode { + fn name(&self) -> &str { + "parquet_source" + } + + fn initialize(&mut self, num_pipelines: usize) { + self.config = { + let metadata_prefetch_size = polars_core::config::get_file_prefetch_size(); + // Limit metadata decode to the number of threads. + let metadata_decode_ahead_size = + (metadata_prefetch_size / 2).min(1 + num_pipelines).max(1); + let row_group_prefetch_size = polars_core::config::get_rg_prefetch_size(); + + Config { + num_pipelines, + metadata_prefetch_size, + metadata_decode_ahead_size, + row_group_prefetch_size, + } + }; + + if self.verbose { + eprintln!("[ParquetSource]: {:?}", &self.config); + } + + self.init_projected_arrow_fields(); + self.physical_predicate = self.predicate.clone().map(phys_expr_to_io_expr); + + let (raw_morsel_receivers, morsel_stream_task_handle) = self.init_raw_morsel_stream(); + + self.async_task_data + .try_lock() + .unwrap() + .replace((raw_morsel_receivers, morsel_stream_task_handle)); + + let row_group_decoder = self.init_row_group_decoder(); + self.row_group_decoder = Some(Arc::new(row_group_decoder)); + } + + fn update_state( + &mut self, + recv: &mut [PortState], + send: &mut [PortState], + ) -> PolarsResult<()> { + use std::sync::atomic::Ordering; + + assert!(recv.is_empty()); + assert_eq!(send.len(), 1); + + if self.is_finished.load(Ordering::Relaxed) { + send[0] = PortState::Done; + assert!( + self.async_task_data.try_lock().unwrap().is_none(), + "should have already been shut down" + ); + } else if send[0] == PortState::Done { + { + // Early shutdown - our port state was set to `Done` by the downstream nodes. 
+ self.shutdown_in_background(); + }; + self.is_finished.store(true, Ordering::Relaxed); + } else { + send[0] = PortState::Ready + } + + Ok(()) + } + + fn spawn<'env, 's>( + &'env mut self, + scope: &'s TaskScope<'s, 'env>, + recv: &mut [Option>], + send: &mut [Option>], + _state: &'s ExecutionState, + join_handles: &mut Vec>>, + ) { + use std::sync::atomic::Ordering; + + assert!(recv.is_empty()); + assert_eq!(send.len(), 1); + assert!(!self.is_finished.load(Ordering::Relaxed)); + + let morsel_senders = send[0].take().unwrap().parallel(); + + let mut async_task_data_guard = self.async_task_data.try_lock().unwrap(); + let (raw_morsel_receivers, _) = async_task_data_guard.as_mut().unwrap(); + + assert_eq!(raw_morsel_receivers.len(), morsel_senders.len()); + + if let Some(v) = self.morsel_stream_starter.take() { + v.send(()).unwrap(); + } + let is_finished = self.is_finished.clone(); + + let task_handles = raw_morsel_receivers + .drain(..) + .zip(morsel_senders) + .map(|(mut raw_morsel_rx, mut morsel_tx)| { + let is_finished = is_finished.clone(); + + scope.spawn_task(TaskPriority::Low, async move { + let source_token = SourceToken::new(); + loop { + let Ok((df, morsel_seq, wait_token)) = raw_morsel_rx.recv().await + else { + is_finished.store(true, Ordering::Relaxed); + break; + }; + + let mut morsel = Morsel::new(df, morsel_seq, source_token.clone()); + morsel.set_consume_token(wait_token); + + if morsel_tx.send(morsel).await.is_err() { + break; + } + + if source_token.stop_requested() { + break; + } + } + + raw_morsel_rx + }) + }) + .collect::>(); + + drop(async_task_data_guard); + + let async_task_data = self.async_task_data.clone(); + + join_handles.push(scope.spawn_task(TaskPriority::Low, async move { + { + let mut async_task_data_guard = async_task_data.try_lock().unwrap(); + let (raw_morsel_receivers, _) = async_task_data_guard.as_mut().unwrap(); + + for handle in task_handles { + raw_morsel_receivers.push(handle.await); + } + } + + if self.is_finished.load(Ordering::Relaxed) { + self.shutdown().await?; + } + + Ok(()) + })) + } + } +} + +impl ParquetSourceNode { + /// # Panics + /// Panics if called more than once. + async fn shutdown_impl( + async_task_data: Arc>, + verbose: bool, + ) -> PolarsResult<()> { + if verbose { + eprintln!("[ParquetSource]: Shutting down"); + } + + let (mut raw_morsel_receivers, morsel_stream_task_handle) = + async_task_data.try_lock().unwrap().take().unwrap(); + + raw_morsel_receivers.clear(); + // Join on the producer handle to catch errors/panics. + // Safety + // * We dropped the receivers on the line above + // * This function is only called once. + morsel_stream_task_handle.await + } + + fn shutdown(&self) -> impl Future> { + if self.verbose { + eprintln!("[ParquetSource]: Shutdown via `shutdown()`"); + } + Self::shutdown_impl(self.async_task_data.clone(), self.verbose) + } + + /// Spawns a task to shut down the source node to avoid blocking the current thread. This is + /// usually called when data is no longer needed from the source node, as such it does not + /// propagate any (non-critical) errors. If on the other hand the source node does not provide + /// more data when requested, then it is more suitable to call [`Self::shutdown`], as it returns + /// a result that can be used to distinguish between whether the data stream stopped due to an + /// error or EOF. 
+ fn shutdown_in_background(&self) { + if self.verbose { + eprintln!("[ParquetSource]: Shutdown via `shutdown_in_background()`"); + } + let async_task_data = self.async_task_data.clone(); + polars_io::pl_async::get_runtime() + .spawn(Self::shutdown_impl(async_task_data, self.verbose)); + } + + /// Constructs the task that provides a morsel stream. + #[allow(clippy::type_complexity)] + fn init_raw_morsel_stream( + &mut self, + ) -> ( + Vec>, + async_executor::AbortOnDropHandle>, + ) { + let verbose = self.verbose; + + let use_statistics = self.options.use_statistics; + + let (mut raw_morsel_senders, raw_morsel_receivers): (Vec<_>, Vec<_>) = + (0..self.config.num_pipelines).map(|_| connector()).unzip(); + + if let Some((_, 0)) = self.file_options.slice { + return ( + raw_morsel_receivers, + async_executor::AbortOnDropHandle::new(async_executor::spawn( + TaskPriority::Low, + std::future::ready(Ok(())), + )), + ); + } + + let reader_schema = self + .file_info + .reader_schema + .as_ref() + .unwrap() + .as_ref() + .unwrap_left() + .clone(); + + let (normalized_slice_oneshot_rx, metadata_rx, metadata_task_handle) = + self.init_metadata_fetcher(); + + let num_pipelines = self.config.num_pipelines; + let row_group_prefetch_size = self.config.row_group_prefetch_size; + let projection = self.file_options.with_columns.clone(); + assert_eq!(self.physical_predicate.is_some(), self.predicate.is_some()); + let predicate = self.physical_predicate.clone(); + let memory_prefetch_func = self.memory_prefetch_func; + + let mut row_group_data_fetcher = RowGroupDataFetcher { + metadata_rx, + use_statistics, + verbose, + reader_schema, + projection, + predicate, + slice_range: None, // Initialized later + memory_prefetch_func, + current_path_index: 0, + current_byte_source: Default::default(), + current_row_groups: Default::default(), + current_row_group_idx: 0, + current_max_row_group_height: 0, + current_row_offset: 0, + current_shared_file_state: Default::default(), + }; + + let row_group_decoder = self.init_row_group_decoder(); + let row_group_decoder = Arc::new(row_group_decoder); + + // Processes row group metadata and spawns I/O tasks to fetch row group data. This is + // currently spawned onto the CPU runtime as it does not directly make any async I/O calls, + // but instead it potentially performs predicate/slice evaluation on metadata. If we observe + // that under heavy CPU load scenarios the I/O throughput drops due to this task not being + // scheduled we can change it to be a high priority task. + let morsel_stream_task_handle = async_executor::spawn(TaskPriority::Low, async move { + let slice_range = { + let Ok(slice) = normalized_slice_oneshot_rx.await else { + // If we are here then the producer probably errored. + drop(row_group_data_fetcher); + return metadata_task_handle.await.unwrap(); + }; + + slice.map(|(offset, len)| offset..offset + len) + }; + + row_group_data_fetcher.slice_range = slice_range; + + // Pins a wait group to a channel index. + struct IndexedWaitGroup { + index: usize, + wait_group: WaitGroup, + } + + impl IndexedWaitGroup { + async fn wait(self) -> Self { + self.wait_group.wait().await; + self + } + } + + // Ensure proper backpressure by only polling the buffered iterator when a wait group + // is free. 
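+            // Concretely: there is one wait group per output channel. Each morsel sent on a
+            // channel carries a token from that channel's wait group (attached as the consume
+            // token in `spawn`), and the loop below only pulls the next decoded row group once
+            // some wait group has no outstanding tokens, i.e. once a downstream pipe has
+            // actually consumed its morsel.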
+ let mut wait_groups = (0..num_pipelines) + .map(|index| { + let wait_group = WaitGroup::default(); + { + let _prime_this_wait_group = wait_group.token(); + } + IndexedWaitGroup { + index, + wait_group: WaitGroup::default(), + } + .wait() + }) + .collect::>(); + + let mut df_stream = row_group_data_fetcher + .into_stream() + .map(|x| async { + match x { + Ok(handle) => handle.await, + Err(e) => Err(e), + } + }) + .buffered(row_group_prefetch_size) + .map(|x| async { + let row_group_decoder = row_group_decoder.clone(); + + match x { + Ok(row_group_data) => { + async_executor::spawn(TaskPriority::Low, async move { + row_group_decoder.row_group_data_to_df(row_group_data).await + }) + .await + }, + Err(e) => Err(e), + } + }) + .buffered( + // Because we are using an ordered buffer, we may suffer from head-of-line blocking, + // so we add a small amount of buffer. + num_pipelines + 4, + ); + + let morsel_seq_ref = &mut MorselSeq::default(); + let mut dfs = vec![].into_iter(); + + 'main: loop { + let Some(mut indexed_wait_group) = wait_groups.next().await else { + break; + }; + + if dfs.len() == 0 { + let Some(v) = df_stream.next().await else { + break; + }; + + let v = v?; + assert!(!v.is_empty()); + + dfs = v.into_iter(); + } + + let mut df = dfs.next().unwrap(); + let morsel_seq = *morsel_seq_ref; + *morsel_seq_ref = morsel_seq.successor(); + + loop { + use crate::async_primitives::connector::SendError; + + let channel_index = indexed_wait_group.index; + let wait_token = indexed_wait_group.wait_group.token(); + + match raw_morsel_senders[channel_index].try_send((df, morsel_seq, wait_token)) { + Ok(_) => { + wait_groups.push(indexed_wait_group.wait()); + break; + }, + Err(SendError::Closed(v)) => { + // The channel assigned to this wait group has been closed, so we will not + // add it back to the list of wait groups, and we will try to send this + // across another channel. + df = v.0 + }, + Err(SendError::Full(_)) => unreachable!(), + } + + let Some(v) = wait_groups.next().await else { + // All channels have closed + break 'main; + }; + + indexed_wait_group = v; + } + } + + // Join on the producer handle to catch errors/panics. + drop(df_stream); + metadata_task_handle.await.unwrap() + }); + + let morsel_stream_task_handle = + async_executor::AbortOnDropHandle::new(morsel_stream_task_handle); + + (raw_morsel_receivers, morsel_stream_task_handle) + } + + /// Constructs the task that fetches file metadata. + /// Note: This must be called AFTER `self.projected_arrow_fields` has been initialized. + /// + /// TODO: During IR conversion the metadata of the first file is already downloaded - see if + /// we can find a way to re-use it. 
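+    // The connector channel returned below carries, per file:
+    // (path_index, row_offset, byte_source, file_metadata, max_row_group_height).
+    // A oneshot channel additionally reports the slice normalized to a positive offset,
+    // and the producer task itself is wrapped in an abort-on-drop handle.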
+ #[allow(clippy::type_complexity)] + fn init_metadata_fetcher( + &mut self, + ) -> ( + tokio::sync::oneshot::Receiver>, + crate::async_primitives::connector::Receiver<( + usize, + usize, + Arc, + FileMetaData, + usize, + )>, + task_handles_ext::AbortOnDropHandle>, + ) { + let verbose = self.verbose; + let io_runtime = polars_io::pl_async::get_runtime(); + + assert!( + !self.projected_arrow_fields.is_empty() + || self.file_options.with_columns.as_deref() == Some(&[]) + ); + let projected_arrow_fields = self.projected_arrow_fields.clone(); + let needs_max_row_group_height_calc = + self.file_options.include_file_paths.is_some() || self.hive_parts.is_some(); + + let (normalized_slice_oneshot_tx, normalized_slice_oneshot_rx) = + tokio::sync::oneshot::channel(); + let (mut metadata_tx, metadata_rx) = connector(); + + let byte_source_builder = self.byte_source_builder.clone(); + + if self.verbose { + eprintln!( + "[ParquetSource]: Byte source builder: {:?}", + &byte_source_builder + ); + } + + let fetch_metadata_bytes_for_path_index = { + let paths = &self.paths; + let cloud_options = Arc::new(self.cloud_options.clone()); + + let paths = paths.clone(); + let cloud_options = cloud_options.clone(); + let byte_source_builder = byte_source_builder.clone(); + + move |path_idx: usize| { + let paths = paths.clone(); + let cloud_options = cloud_options.clone(); + let byte_source_builder = byte_source_builder.clone(); + + let handle = io_runtime.spawn(async move { + let mut byte_source = Arc::new( + byte_source_builder + .try_build_from_path( + paths[path_idx].to_str().unwrap(), + cloud_options.as_ref().as_ref(), + ) + .await?, + ); + let (metadata_bytes, maybe_full_bytes) = + read_parquet_metadata_bytes(byte_source.as_ref(), verbose).await?; + + if let Some(v) = maybe_full_bytes { + if !matches!(byte_source.as_ref(), DynByteSource::MemSlice(_)) { + if verbose { + eprintln!( + "[ParquetSource]: Parquet file was fully fetched during \ + metadata read ({} bytes).", + v.len(), + ); + } + + byte_source = Arc::new(DynByteSource::from(MemSliceByteSource(v))) + } + } + + PolarsResult::Ok((path_idx, byte_source, metadata_bytes)) + }); + + let handle = task_handles_ext::AbortOnDropHandle(handle); + + std::future::ready(handle) + } + }; + + let process_metadata_bytes = { + move |handle: task_handles_ext::AbortOnDropHandle< + PolarsResult<(usize, Arc, MemSlice)>, + >| { + let projected_arrow_fields = projected_arrow_fields.clone(); + // Run on CPU runtime - metadata deserialization is expensive, especially + // for very wide tables. 
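+                // This closure is applied to each prefetched metadata blob further below:
+                // up to `metadata_prefetch_size` files have their metadata bytes fetched
+                // concurrently on the I/O runtime, and up to `metadata_decode_ahead_size`
+                // of them are deserialized in parallel on the CPU executor.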
+ let handle = async_executor::spawn(TaskPriority::Low, async move { + let (path_index, byte_source, metadata_bytes) = handle.await.unwrap()?; + + let metadata = polars_parquet::parquet::read::deserialize_metadata( + metadata_bytes.as_ref(), + metadata_bytes.len() * 2 + 1024, + )?; + + ensure_metadata_has_projected_fields( + projected_arrow_fields.as_ref(), + &metadata, + )?; + + let file_max_row_group_height = if needs_max_row_group_height_calc { + metadata + .row_groups + .iter() + .map(|x| x.num_rows()) + .max() + .unwrap_or(0) + } else { + 0 + }; + + PolarsResult::Ok((path_index, byte_source, metadata, file_max_row_group_height)) + }); + + async_executor::AbortOnDropHandle::new(handle) + } + }; + + let metadata_prefetch_size = self.config.metadata_prefetch_size; + let metadata_decode_ahead_size = self.config.metadata_decode_ahead_size; + + let (start_tx, start_rx) = tokio::sync::oneshot::channel(); + self.morsel_stream_starter = Some(start_tx); + + let metadata_task_handle = if self + .file_options + .slice + .map(|(offset, _)| offset >= 0) + .unwrap_or(true) + { + normalized_slice_oneshot_tx + .send( + self.file_options + .slice + .map(|(offset, len)| (offset as usize, len)), + ) + .unwrap(); + + // Safety: `offset + len` does not overflow. + let slice_range = self + .file_options + .slice + .map(|(offset, len)| offset as usize..offset as usize + len); + + let mut metadata_stream = futures::stream::iter(0..self.paths.len()) + .map(fetch_metadata_bytes_for_path_index) + .buffered(metadata_prefetch_size) + .map(process_metadata_bytes) + .buffered(metadata_decode_ahead_size); + + let paths = self.paths.clone(); + + // We need to be able to both stop early as well as skip values, which is easier to do + // using a custom task instead of futures::stream + io_runtime.spawn(async move { + let current_row_offset_ref = &mut 0usize; + let current_path_index_ref = &mut 0usize; + + if start_rx.await.is_err() { + return Ok(()); + } + + if verbose { + eprintln!("[ParquetSource]: Starting data fetch") + } + + loop { + let current_path_index = *current_path_index_ref; + *current_path_index_ref += 1; + + let Some(v) = metadata_stream.next().await else { + break; + }; + + let (path_index, byte_source, metadata, file_max_row_group_height) = v + .map_err(|err| { + err.wrap_msg(|msg| { + format!( + "error at path (index: {}, path: {}): {}", + current_path_index, + paths[current_path_index].to_str().unwrap(), + msg + ) + }) + })?; + + assert_eq!(path_index, current_path_index); + + let current_row_offset = *current_row_offset_ref; + *current_row_offset_ref = current_row_offset.saturating_add(metadata.num_rows); + + if let Some(slice_range) = slice_range.clone() { + match SplitSlicePosition::split_slice_at_file( + current_row_offset, + metadata.num_rows, + slice_range, + ) { + SplitSlicePosition::Before => { + if verbose { + eprintln!( + "[ParquetSource]: Slice pushdown: \ + Skipped file at index {} ({} rows)", + current_path_index, metadata.num_rows + ); + } + continue; + }, + SplitSlicePosition::After => unreachable!(), + SplitSlicePosition::Overlapping(..) 
=> {}, + }; + }; + + if metadata_tx + .send(( + path_index, + current_row_offset, + byte_source, + metadata, + file_max_row_group_height, + )) + .await + .is_err() + { + break; + } + + if let Some(slice_range) = slice_range.as_ref() { + if *current_row_offset_ref >= slice_range.end { + if verbose { + eprintln!( + "[ParquetSource]: Slice pushdown: \ + Stopped reading at file at index {} \ + (remaining {} files will not be read)", + current_path_index, + paths.len() - current_path_index - 1, + ); + } + break; + } + }; + } + + Ok(()) + }) + } else { + // Walk the files in reverse to translate the slice into a positive offset. + let slice = self.file_options.slice.unwrap(); + let slice_start_as_n_from_end = -slice.0 as usize; + + let mut metadata_stream = futures::stream::iter((0..self.paths.len()).rev()) + .map(fetch_metadata_bytes_for_path_index) + .buffered(metadata_prefetch_size) + .map(process_metadata_bytes) + .buffered(metadata_decode_ahead_size); + + // Note: + // * We want to wait until the first morsel is requested before starting this + let init_negative_slice_and_metadata = async move { + let mut processed_metadata_rev = vec![]; + let mut cum_rows = 0; + + while let Some(v) = metadata_stream.next().await { + let v = v?; + let (_, _, metadata, _) = &v; + cum_rows += metadata.num_rows; + processed_metadata_rev.push(v); + + if cum_rows >= slice_start_as_n_from_end { + break; + } + } + + let (start, len) = if slice_start_as_n_from_end > cum_rows { + // We need to trim the slice, e.g. SLICE[offset: -100, len: 75] on a file of 50 + // rows should only give the first 25 rows. + let first_file_position = slice_start_as_n_from_end - cum_rows; + (0, slice.1.saturating_sub(first_file_position)) + } else { + (cum_rows - slice_start_as_n_from_end, slice.1) + }; + + if len == 0 { + processed_metadata_rev.clear(); + } + + normalized_slice_oneshot_tx + .send(Some((start, len))) + .unwrap(); + + let slice_range = start..(start + len); + + PolarsResult::Ok((slice_range, processed_metadata_rev, cum_rows)) + }; + + let path_count = self.paths.len(); + + io_runtime.spawn(async move { + if start_rx.await.is_err() { + return Ok(()); + } + + if verbose { + eprintln!("[ParquetSource]: Starting data fetch (negative slice)") + } + + let (slice_range, processed_metadata_rev, cum_rows) = + async_executor::AbortOnDropHandle::new(async_executor::spawn( + TaskPriority::Low, + init_negative_slice_and_metadata, + )) + .await?; + + if verbose { + if let Some((path_index, ..)) = processed_metadata_rev.last() { + eprintln!( + "[ParquetSource]: Slice pushdown: Negatively-offsetted slice {:?} \ + begins at file index {}, translated to {:?}", + slice, path_index, slice_range + ); + } else { + eprintln!( + "[ParquetSource]: Slice pushdown: Negatively-offsetted slice {:?} \ + skipped all files ({} files containing {} rows)", + slice, path_count, cum_rows + ) + } + } + + let metadata_iter = processed_metadata_rev.into_iter().rev(); + let current_row_offset_ref = &mut 0usize; + + for (current_path_index, byte_source, metadata, file_max_row_group_height) in + metadata_iter + { + let current_row_offset = *current_row_offset_ref; + *current_row_offset_ref = current_row_offset.saturating_add(metadata.num_rows); + + assert!(matches!( + SplitSlicePosition::split_slice_at_file( + current_row_offset, + metadata.num_rows, + slice_range.clone(), + ), + SplitSlicePosition::Overlapping(..) 
+ )); + + if metadata_tx + .send(( + current_path_index, + current_row_offset, + byte_source, + metadata, + file_max_row_group_height, + )) + .await + .is_err() + { + break; + } + + if *current_row_offset_ref >= slice_range.end { + if verbose { + eprintln!( + "[ParquetSource]: Slice pushdown: \ + Stopped reading at file at index {} \ + (remaining {} files will not be read)", + current_path_index, + path_count - current_path_index - 1, + ); + } + break; + } + } + + Ok(()) + }) + }; + + let metadata_task_handle = task_handles_ext::AbortOnDropHandle(metadata_task_handle); + + ( + normalized_slice_oneshot_rx, + metadata_rx, + metadata_task_handle, + ) + } + + /// Creates a `RowGroupDecoder` that turns `RowGroupData` into DataFrames. + /// This must be called AFTER the following have been initialized: + /// * `self.projected_arrow_fields` + /// * `self.physical_predicate` + fn init_row_group_decoder(&self) -> RowGroupDecoder { + assert!( + !self.projected_arrow_fields.is_empty() + || self.file_options.with_columns.as_deref() == Some(&[]) + ); + assert_eq!(self.predicate.is_some(), self.physical_predicate.is_some()); + + let paths = self.paths.clone(); + let hive_partitions = self.hive_parts.clone(); + let hive_partitions_width = hive_partitions + .as_deref() + .map(|x| x[0].get_statistics().column_stats().len()) + .unwrap_or(0); + let include_file_paths = self.file_options.include_file_paths.clone(); + let projected_arrow_fields = self.projected_arrow_fields.clone(); + let row_index = self.file_options.row_index.clone(); + let physical_predicate = self.physical_predicate.clone(); + let ideal_morsel_size = get_ideal_morsel_size(); + + RowGroupDecoder { + paths, + hive_partitions, + hive_partitions_width, + include_file_paths, + projected_arrow_fields, + row_index, + physical_predicate, + ideal_morsel_size, + } + } + + fn init_projected_arrow_fields(&mut self) { + let reader_schema = self + .file_info + .reader_schema + .as_ref() + .unwrap() + .as_ref() + .unwrap_left() + .clone(); + + self.projected_arrow_fields = + if let Some(columns) = self.file_options.with_columns.as_deref() { + columns + .iter() + .map(|x| reader_schema.get(x).unwrap().clone()) + .collect() + } else { + reader_schema.iter_values().cloned().collect() + }; + + if self.verbose { + eprintln!( + "[ParquetSource]: {} columns to be projected from {} files", + self.projected_arrow_fields.len(), + self.paths.len(), + ); + } + } +} + +#[derive(Debug)] +struct Config { + num_pipelines: usize, + /// Number of files to pre-fetch metadata for concurrently + metadata_prefetch_size: usize, + /// Number of files to decode metadata for in parallel in advance + metadata_decode_ahead_size: usize, + /// Number of row groups to pre-fetch concurrently, this can be across files + row_group_prefetch_size: usize, +} + +/// Represents byte-data that can be transformed into a DataFrame after some computation. 
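+// A `RowGroupData` bundles the fetched bytes for one row group with its metadata, any
+// slice that applies to it, and the per-file shared state; the `RowGroupDecoder` used
+// further on turns it into one or more DataFrames.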
+struct RowGroupData {
+    byte_source: FetchedBytes,
+    path_index: usize,
+    row_offset: usize,
+    slice: Option<(usize, usize)>,
+    file_max_row_group_height: usize,
+    row_group_metadata: RowGroupMetaData,
+    shared_file_state: Arc<tokio::sync::OnceCell<SharedFileState>>,
+}
+
+struct RowGroupDataFetcher {
+    metadata_rx: crate::async_primitives::connector::Receiver<(
+        usize,
+        usize,
+        Arc<DynByteSource>,
+        FileMetaData,
+        usize,
+    )>,
+    use_statistics: bool,
+    verbose: bool,
+    reader_schema: Arc<ArrowSchema>,
+    projection: Option<Arc<[PlSmallStr]>>,
+    predicate: Option<Arc<dyn PhysicalIoExpr>>,
+    slice_range: Option<std::ops::Range<usize>>,
+    memory_prefetch_func: fn(&[u8]) -> (),
+    current_path_index: usize,
+    current_byte_source: Arc<DynByteSource>,
+    current_row_groups: std::vec::IntoIter<RowGroupMetaData>,
+    current_row_group_idx: usize,
+    current_max_row_group_height: usize,
+    current_row_offset: usize,
+    current_shared_file_state: Arc<tokio::sync::OnceCell<SharedFileState>>,
+}
+
+impl RowGroupDataFetcher {
+    fn into_stream(self) -> RowGroupDataStream {
+        RowGroupDataStream::new(self)
+    }
+
+    async fn init_next_file_state(&mut self) -> bool {
+        let Ok((path_index, row_offset, byte_source, metadata, file_max_row_group_height)) =
+            self.metadata_rx.recv().await
+        else {
+            return false;
+        };
+
+        self.current_path_index = path_index;
+        self.current_byte_source = byte_source;
+        self.current_max_row_group_height = file_max_row_group_height;
+        // The metadata task also sends a row offset to start counting from as it may skip files
+        // during slice pushdown.
+        self.current_row_offset = row_offset;
+        self.current_row_group_idx = 0;
+        self.current_row_groups = metadata.row_groups.into_iter();
+        self.current_shared_file_state = Default::default();
+
+        true
+    }
+
+    async fn next(
+        &mut self,
+    ) -> Option<PolarsResult<async_executor::AbortOnDropHandle<PolarsResult<RowGroupData>>>> {
+        'main: loop {
+            for row_group_metadata in self.current_row_groups.by_ref() {
+                let current_row_offset = self.current_row_offset;
+                let current_row_group_idx = self.current_row_group_idx;
+
+                let num_rows = row_group_metadata.num_rows();
+
+                self.current_row_offset = current_row_offset.saturating_add(num_rows);
+                self.current_row_group_idx += 1;
+
+                if self.use_statistics
+                    && !match read_this_row_group(
+                        self.predicate.as_deref(),
+                        &row_group_metadata,
+                        self.reader_schema.as_ref(),
+                    ) {
+                        Ok(v) => v,
+                        Err(e) => return Some(Err(e)),
+                    }
+                {
+                    if self.verbose {
+                        eprintln!(
+                            "[ParquetSource]: Predicate pushdown: \
+                            Skipped row group {} in file {} ({} rows)",
+                            current_row_group_idx, self.current_path_index, num_rows
+                        );
+                    }
+                    continue;
+                }
+
+                if num_rows > IdxSize::MAX as usize {
+                    let msg = operation_exceeded_idxsize_msg(
+                        format!("number of rows in row group ({})", num_rows).as_str(),
+                    );
+                    return Some(Err(polars_err!(ComputeError: msg)));
+                }
+
+                let slice = if let Some(slice_range) = self.slice_range.clone() {
+                    let (offset, len) = match SplitSlicePosition::split_slice_at_file(
+                        current_row_offset,
+                        num_rows,
+                        slice_range,
+                    ) {
+                        SplitSlicePosition::Before => {
+                            if self.verbose {
+                                eprintln!(
+                                    "[ParquetSource]: Slice pushdown: \
+                                    Skipped row group {} in file {} ({} rows)",
+                                    current_row_group_idx, self.current_path_index, num_rows
+                                );
+                            }
+                            continue;
+                        },
+                        SplitSlicePosition::After => {
+                            if self.verbose {
+                                eprintln!(
+                                    "[ParquetSource]: Slice pushdown: \
+                                    Stop at row group {} in file {} \
+                                    (remaining {} row groups will not be read)",
+                                    current_row_group_idx,
+                                    self.current_path_index,
+                                    self.current_row_groups.len(),
+                                );
+                            };
+                            break 'main;
+                        },
+                        SplitSlicePosition::Overlapping(offset, len) => (offset, len),
+                    };
+
+                    Some((offset, len))
+                } else {
+                    None
+                };
+
+                let current_byte_source = self.current_byte_source.clone();
+                let projection =
self.projection.clone(); + let current_shared_file_state = self.current_shared_file_state.clone(); + let memory_prefetch_func = self.memory_prefetch_func; + let io_runtime = polars_io::pl_async::get_runtime(); + let current_path_index = self.current_path_index; + let current_max_row_group_height = self.current_max_row_group_height; + + // Push calculation of byte ranges to a task to run in parallel, as it can be + // expensive for very wide tables and projections. + let handle = async_executor::spawn(TaskPriority::Low, async move { + let byte_source = if let DynByteSource::MemSlice(mem_slice) = + current_byte_source.as_ref() + { + // Skip byte range calculation for `no_prefetch`. + if memory_prefetch_func as usize != mem_prefetch_funcs::no_prefetch as usize + { + let slice = mem_slice.0.as_ref(); + + if let Some(columns) = projection.as_ref() { + for range in get_row_group_byte_ranges_for_projection( + &row_group_metadata, + columns.as_ref(), + ) { + memory_prefetch_func(unsafe { + slice.get_unchecked_release(range) + }) + } + } else { + let mut iter = get_row_group_byte_ranges(&row_group_metadata); + let first = iter.next().unwrap(); + let range = + iter.fold(first, |l, r| l.start.min(r.start)..l.end.max(r.end)); + + memory_prefetch_func(unsafe { slice.get_unchecked_release(range) }) + }; + } + + // We have a mmapped or in-memory slice representing the entire + // file that can be sliced directly, so we can skip the byte-range + // calculations and HashMap allocation. + let mem_slice = mem_slice.0.clone(); + FetchedBytes::MemSlice { + offset: 0, + mem_slice, + } + } else if let Some(columns) = projection.as_ref() { + let ranges = get_row_group_byte_ranges_for_projection( + &row_group_metadata, + columns.as_ref(), + ) + .collect::>(); + + let bytes = { + let ranges_2 = ranges.clone(); + task_handles_ext::AbortOnDropHandle(io_runtime.spawn(async move { + current_byte_source.get_ranges(ranges_2.as_ref()).await + })) + .await + .unwrap()? + }; + + assert_eq!(bytes.len(), ranges.len()); + + let mut bytes_map = PlHashMap::with_capacity(ranges.len()); + + for (range, bytes) in ranges.iter().zip(bytes) { + memory_prefetch_func(bytes.as_ref()); + let v = bytes_map.insert(range.start, bytes); + debug_assert!(v.is_none(), "duplicate range start {}", range.start); + } + + FetchedBytes::BytesMap(bytes_map) + } else { + // We have a dedicated code-path for a full projection that performs a + // single range request for the entire row group. During testing this + // provided much higher throughput from cloud than making multiple range + // request with `get_ranges()`. + let mut iter = get_row_group_byte_ranges(&row_group_metadata); + let mut ranges = Vec::with_capacity(iter.len()); + let first = iter.next().unwrap(); + ranges.push(first.clone()); + let full_range = iter.fold(first, |l, r| { + ranges.push(r.clone()); + l.start.min(r.start)..l.end.max(r.end) + }); + + let mem_slice = { + let full_range_2 = full_range.clone(); + task_handles_ext::AbortOnDropHandle(io_runtime.spawn(async move { + current_byte_source.get_range(full_range_2).await + })) + .await + .unwrap()? 
+                    };
+
+                    FetchedBytes::MemSlice {
+                        offset: full_range.start,
+                        mem_slice,
+                    }
+                };
+
+                PolarsResult::Ok(RowGroupData {
+                    byte_source,
+                    path_index: current_path_index,
+                    row_offset: current_row_offset,
+                    slice,
+                    file_max_row_group_height: current_max_row_group_height,
+                    row_group_metadata,
+                    shared_file_state: current_shared_file_state.clone(),
+                })
+            });
+
+            let handle = async_executor::AbortOnDropHandle::new(handle);
+            return Some(Ok(handle));
+        }
+
+        // Initialize state to the next file.
+        if !self.init_next_file_state().await {
+            break;
+        }
+    }
+
+    None
+    }
+}
+
+enum FetchedBytes {
+    MemSlice { mem_slice: MemSlice, offset: usize },
+    BytesMap(PlHashMap<usize, MemSlice>),
+}
+
+impl FetchedBytes {
+    fn get_range(&self, range: std::ops::Range<usize>) -> MemSlice {
+        match self {
+            Self::MemSlice { mem_slice, offset } => {
+                let offset = *offset;
+                debug_assert!(range.start >= offset);
+                mem_slice.slice(range.start - offset..range.end - offset)
+            },
+            Self::BytesMap(v) => {
+                let v = v.get(&range.start).unwrap();
+                debug_assert_eq!(v.len(), range.len());
+                v.clone()
+            },
+        }
+    }
+}
+
+#[rustfmt::skip]
+type RowGroupDataStreamFut = std::pin::Pin<Box<
+    dyn Future<
+        Output =
+            (
+                Box<RowGroupDataFetcher>,
+                Option <
+                    PolarsResult <
+                        async_executor::AbortOnDropHandle <
+                            PolarsResult <
+                                RowGroupData > > > >
+            )
+    > + Send
+>>;
+
+struct RowGroupDataStream {
+    current_future: RowGroupDataStreamFut,
+}
+
+impl RowGroupDataStream {
+    fn new(row_group_data_fetcher: RowGroupDataFetcher) -> Self {
+        // [`RowGroupDataFetcher`] is a big struct, so we Box it once here to avoid boxing it on
+        // every `next()` call.
+        let current_future = Self::call_next_owned(Box::new(row_group_data_fetcher));
+        Self { current_future }
+    }
+
+    fn call_next_owned(
+        mut row_group_data_fetcher: Box<RowGroupDataFetcher>,
+    ) -> RowGroupDataStreamFut {
+        Box::pin(async move {
+            let out = row_group_data_fetcher.next().await;
+            (row_group_data_fetcher, out)
+        })
+    }
+}
+
+impl futures::stream::Stream for RowGroupDataStream {
+    type Item = PolarsResult<async_executor::AbortOnDropHandle<PolarsResult<RowGroupData>>>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        use std::pin::Pin;
+        use std::task::Poll;
+
+        match Pin::new(&mut self.current_future.as_mut()).poll(cx) {
+            Poll::Ready((row_group_data_fetcher, out)) => {
+                if out.is_some() {
+                    self.current_future = Self::call_next_owned(row_group_data_fetcher);
+                }
+
+                Poll::Ready(out)
+            },
+            Poll::Pending => Poll::Pending,
+        }
+    }
+}
+
+/// State shared across row groups for a single file.
+struct SharedFileState {
+    path_index: usize,
+    hive_series: Vec<Series>,
+    file_path_series: Option<Series>,
+}
+
+/// Turns row group data into DataFrames.
+struct RowGroupDecoder {
+    paths: Arc<[PathBuf]>,
+    hive_partitions: Option<Arc<Vec<HivePartitions>>>,
+    hive_partitions_width: usize,
+    include_file_paths: Option<PlSmallStr>,
+    projected_arrow_fields: Arc<[polars_core::prelude::ArrowField]>,
+    row_index: Option<RowIndex>,
+    physical_predicate: Option<Arc<dyn PhysicalIoExpr>>,
+    ideal_morsel_size: usize,
+}
+
+impl RowGroupDecoder {
+    async fn row_group_data_to_df(
+        &self,
+        row_group_data: RowGroupData,
+    ) -> PolarsResult<Vec<DataFrame>> {
+        let row_group_data = Arc::new(row_group_data);
+
+        let out_width = self.row_index.is_some() as usize
+            + self.projected_arrow_fields.len()
+            + self.hive_partitions_width
+            + self.include_file_paths.is_some() as usize;
+
+        let mut out_columns = Vec::with_capacity(out_width);
+
+        if self.row_index.is_some() {
+            // Add a placeholder so that we don't have to shift the entire vec
+            // later.
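+            // (The actual row-index column is built further down, once the decoded projection
+            // columns have determined the height, and is then written into this slot 0.)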
+ out_columns.push(Series::default()); + } + + let slice_range = row_group_data + .slice + .map(|(offset, len)| offset..offset + len) + .unwrap_or(0..row_group_data.row_group_metadata.num_rows()); + + let projected_arrow_fields = &self.projected_arrow_fields; + let projected_arrow_fields = projected_arrow_fields.clone(); + + let row_group_data_2 = row_group_data.clone(); + let slice_range_2 = slice_range.clone(); + + // Minimum number of values to amortize the overhead of spawning tasks. + // This value is arbitrarily chosen. + const VALUES_PER_THREAD: usize = 16_777_216; + let n_rows = row_group_data.row_group_metadata.num_rows(); + let cols_per_task = 1 + VALUES_PER_THREAD / n_rows; + + let decode_fut_iter = (0..self.projected_arrow_fields.len()) + .step_by(cols_per_task) + .map(move |offset| { + let row_group_data = row_group_data_2.clone(); + let slice_range = slice_range_2.clone(); + let projected_arrow_fields = projected_arrow_fields.clone(); + + async move { + (offset + ..offset + .saturating_add(cols_per_task) + .min(projected_arrow_fields.len())) + .map(|i| { + let arrow_field = projected_arrow_fields[i].clone(); + + let columns_to_deserialize = row_group_data + .row_group_metadata + .columns_under_root_iter(&arrow_field.name) + .map(|col_md| { + let byte_range = col_md.byte_range(); + + ( + col_md, + row_group_data.byte_source.get_range( + byte_range.start as usize..byte_range.end as usize, + ), + ) + }) + .collect::>(); + + assert!( + slice_range.end <= row_group_data.row_group_metadata.num_rows() + ); + + let array = polars_io::prelude::_internal::to_deserializer( + columns_to_deserialize, + arrow_field.clone(), + Some(polars_parquet::read::Filter::Range(slice_range.clone())), + )?; + + let series = Series::try_from((&arrow_field, array))?; + + // TODO: Also load in the metadata. + + PolarsResult::Ok(series) + }) + .collect::>>() + } + }); + + if decode_fut_iter.len() > 1 { + for handle in decode_fut_iter.map(|fut| { + async_executor::AbortOnDropHandle::new(async_executor::spawn( + TaskPriority::Low, + fut, + )) + }) { + out_columns.extend(handle.await?); + } + } else { + for fut in decode_fut_iter { + out_columns.extend(fut.await?); + } + } + + let projection_height = if self.projected_arrow_fields.is_empty() { + slice_range.len() + } else { + debug_assert!(out_columns.len() > self.row_index.is_some() as usize); + out_columns.last().unwrap().len() + }; + + if let Some(RowIndex { name, offset }) = self.row_index.as_ref() { + let Some(offset) = (|| { + let offset = offset + .checked_add((row_group_data.row_offset + slice_range.start) as IdxSize)?; + offset.checked_add(projection_height as IdxSize)?; + + Some(offset) + })() else { + let msg = format!( + "adding a row index column with offset {} overflows at {} rows", + offset, + row_group_data.row_offset + slice_range.end + ); + polars_bail!(ComputeError: msg) + }; + + // The DataFrame can be empty at this point if no columns were projected from the file, + // so we create the row index column manually instead of using `df.with_row_index` to + // ensure it has the correct number of rows. 
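+            // Worked example: with a user offset of 10, a row group whose first row sits at
+            // global row 1_000 and a slice starting 5 rows into that row group, the first
+            // index written is 10 + 1_000 + 5 = 1_015.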
+ let mut ca = IdxCa::from_vec( + name.clone(), + (offset..offset + projection_height as IdxSize).collect(), + ); + ca.set_sorted_flag(IsSorted::Ascending); + + out_columns[0] = ca.into_series(); + } + + let shared_file_state = row_group_data + .shared_file_state + .get_or_init(|| async { + let path_index = row_group_data.path_index; + + let hive_series = if let Some(hp) = self.hive_partitions.as_deref() { + let mut v = hp[path_index].materialize_partition_columns(); + for s in v.iter_mut() { + *s = s.new_from_index(0, row_group_data.file_max_row_group_height); + } + v + } else { + vec![] + }; + + let file_path_series = self.include_file_paths.clone().map(|file_path_col| { + StringChunked::full( + file_path_col, + self.paths[path_index].to_str().unwrap(), + row_group_data.file_max_row_group_height, + ) + .into_series() + }); + + SharedFileState { + path_index, + hive_series, + file_path_series, + } + }) + .await; + + assert_eq!(shared_file_state.path_index, row_group_data.path_index); + + for s in &shared_file_state.hive_series { + debug_assert!(s.len() >= projection_height); + out_columns.push(s.slice(0, projection_height)); + } + + if let Some(file_path_series) = &shared_file_state.file_path_series { + debug_assert!(file_path_series.len() >= projection_height); + out_columns.push(file_path_series.slice(0, projection_height)); + } + + let df = unsafe { DataFrame::new_no_checks(out_columns) }; + + // Re-calculate: A slice may have been applied. + let cols_per_task = 1 + VALUES_PER_THREAD / df.height(); + + let df = if let Some(predicate) = self.physical_predicate.as_deref() { + let mask = predicate.evaluate_io(&df)?; + let mask = mask.bool().unwrap(); + + if cols_per_task <= df.width() { + df._filter_seq(mask)? + } else { + let mask = mask.clone(); + let cols = Arc::new(df.take_columns()); + let mut out_cols = Vec::with_capacity(cols.len()); + + for handle in (0..cols.len()) + .step_by(cols_per_task) + .map(move |offset| { + let cols = cols.clone(); + let mask = mask.clone(); + async move { + cols[offset..offset.saturating_add(cols_per_task).min(cols.len())] + .iter() + .map(|s| s.filter(&mask)) + .collect::>>() + } + }) + .map(|fut| { + async_executor::AbortOnDropHandle::new(async_executor::spawn( + TaskPriority::Low, + fut, + )) + }) + { + out_cols.extend(handle.await?); + } + + unsafe { DataFrame::new_no_checks(out_cols) } + } + } else { + df + }; + + assert_eq!(df.width(), out_width); + + let n_morsels = if df.height() > 3 * self.ideal_morsel_size / 2 { + // num_rows > (1.5 * ideal_morsel_size) + (df.height() / self.ideal_morsel_size).max(2) + } else { + 1 + } as u64; + + if n_morsels == 1 { + return Ok(vec![df]); + } + + let rows_per_morsel = 1 + df.height() / n_morsels as usize; + + let out = (0..i64::try_from(df.height()).unwrap()) + .step_by(rows_per_morsel) + .map(|offset| df.slice(offset, rows_per_morsel)) + .collect::>(); + + Ok(out) + } +} + +/// Read the metadata bytes of a parquet file, does not decode the bytes. If during metadata fetch +/// the bytes of the entire file are loaded, it is returned in the second return value. 
+async fn read_parquet_metadata_bytes( + byte_source: &DynByteSource, + verbose: bool, +) -> PolarsResult<(MemSlice, Option)> { + use polars_parquet::parquet::error::ParquetError; + use polars_parquet::parquet::PARQUET_MAGIC; + + const FOOTER_HEADER_SIZE: usize = polars_parquet::parquet::FOOTER_SIZE as usize; + + let file_size = byte_source.get_size().await?; + + if file_size < FOOTER_HEADER_SIZE { + return Err(ParquetError::OutOfSpec(format!( + "file size ({}) is less than minimum size required to store parquet footer ({})", + file_size, FOOTER_HEADER_SIZE + )) + .into()); + } + + let estimated_metadata_size = if let DynByteSource::MemSlice(_) = byte_source { + // Mmapped or in-memory, reads are free. + file_size + } else { + (file_size / 2048).clamp(16_384, 131_072).min(file_size) + }; + + let bytes = byte_source + .get_range((file_size - estimated_metadata_size)..file_size) + .await?; + + let footer_header_bytes = bytes.slice((bytes.len() - FOOTER_HEADER_SIZE)..bytes.len()); + + let (v, remaining) = footer_header_bytes.split_at(4); + let footer_size = i32::from_le_bytes(v.try_into().unwrap()); + + if remaining != PARQUET_MAGIC { + return Err(ParquetError::OutOfSpec(format!( + r#"expected parquet magic bytes "{}" in footer, got "{}" instead"#, + std::str::from_utf8(&PARQUET_MAGIC).unwrap(), + String::from_utf8_lossy(remaining) + )) + .into()); + } + + if footer_size < 0 { + return Err(ParquetError::OutOfSpec(format!( + "expected positive footer size, got {} instead", + footer_size + )) + .into()); + } + + let footer_size = footer_size as usize + FOOTER_HEADER_SIZE; + + if file_size < footer_size { + return Err(ParquetError::OutOfSpec(format!( + "file size ({}) is less than the indicated footer size ({})", + file_size, footer_size + )) + .into()); + } + + if bytes.len() < footer_size { + debug_assert!(!matches!(byte_source, DynByteSource::MemSlice(_))); + if verbose { + eprintln!( + "[ParquetSource]: Extra {} bytes need to be fetched for metadata \ + (initial estimate = {}, actual size = {})", + footer_size - estimated_metadata_size, + bytes.len(), + footer_size, + ); + } + + let mut out = Vec::with_capacity(footer_size); + let offset = file_size - footer_size; + let len = footer_size - bytes.len(); + let delta_bytes = byte_source.get_range(offset..(offset + len)).await?; + + debug_assert!(out.capacity() >= delta_bytes.len() + bytes.len()); + + out.extend_from_slice(&delta_bytes); + out.extend_from_slice(&bytes); + + Ok((MemSlice::from_vec(out), None)) + } else { + if verbose && !matches!(byte_source, DynByteSource::MemSlice(_)) { + eprintln!( + "[ParquetSource]: Fetched all bytes for metadata on first try \ + (initial estimate = {}, actual size = {}, excess = {})", + bytes.len(), + footer_size, + estimated_metadata_size - footer_size, + ); + } + + let metadata_bytes = bytes.slice((bytes.len() - footer_size)..bytes.len()); + + if bytes.len() == file_size { + Ok((metadata_bytes, Some(bytes))) + } else { + debug_assert!(!matches!(byte_source, DynByteSource::MemSlice(_))); + let metadata_bytes = if bytes.len() - footer_size >= bytes.len() { + // Re-allocate to drop the excess bytes + MemSlice::from_vec(metadata_bytes.to_vec()) + } else { + metadata_bytes + }; + + Ok((metadata_bytes, None)) + } + } +} + +fn get_row_group_byte_ranges( + row_group_metadata: &RowGroupMetaData, +) -> impl ExactSizeIterator> + '_ { + row_group_metadata + .byte_ranges_iter() + .map(|byte_range| byte_range.start as usize..byte_range.end as usize) +} + +fn get_row_group_byte_ranges_for_projection<'a>( + 
row_group_metadata: &'a RowGroupMetaData, + columns: &'a [PlSmallStr], +) -> impl Iterator> + 'a { + columns.iter().flat_map(|col_name| { + row_group_metadata + .columns_under_root_iter(col_name) + .map(|col| { + let byte_range = col.byte_range(); + byte_range.start as usize..byte_range.end as usize + }) + }) +} + +/// Ensures that a parquet file has all the necessary columns for a projection with the correct +/// dtype. There are no ordering requirements and extra columns are permitted. +fn ensure_metadata_has_projected_fields( + projected_fields: &[polars_core::prelude::ArrowField], + metadata: &FileMetaData, +) -> PolarsResult<()> { + let schema = polars_parquet::arrow::read::infer_schema(metadata)?; + + // Note: We convert to Polars-native dtypes for timezone normalization. + let mut schema = schema + .into_iter_values() + .map(|x| { + let dtype = DataType::from_arrow(&x.dtype, true); + (x.name, dtype) + }) + .collect::>(); + + for field in projected_fields { + let Some(dtype) = schema.remove(&field.name) else { + polars_bail!(SchemaMismatch: "did not find column: {}", field.name) + }; + + let expected_dtype = DataType::from_arrow(&field.dtype, true); + + if dtype != expected_dtype { + polars_bail!(SchemaMismatch: "data type mismatch for column {}: found: {}, expected: {}", + &field.name, dtype, expected_dtype + ) + } + } + + Ok(()) +} + +fn get_memory_prefetch_func(verbose: bool) -> fn(&[u8]) -> () { + let memory_prefetch_func = match std::env::var("POLARS_MEMORY_PREFETCH").ok().as_deref() { + None => { + // Sequential advice was observed to provide speedups on Linux. + // ref https://github.com/pola-rs/polars/pull/18152#discussion_r1721701965 + #[cfg(target_os = "linux")] + { + mem_prefetch_funcs::madvise_sequential + } + #[cfg(not(target_os = "linux"))] + { + mem_prefetch_funcs::no_prefetch + } + }, + Some("no_prefetch") => mem_prefetch_funcs::no_prefetch, + Some("prefetch_l2") => mem_prefetch_funcs::prefetch_l2, + Some("madvise_sequential") => { + #[cfg(target_family = "unix")] + { + mem_prefetch_funcs::madvise_sequential + } + #[cfg(not(target_family = "unix"))] + { + panic!("POLARS_MEMORY_PREFETCH=madvise_sequential is not supported by this system"); + } + }, + Some("madvise_willneed") => { + #[cfg(target_family = "unix")] + { + mem_prefetch_funcs::madvise_willneed + } + #[cfg(not(target_family = "unix"))] + { + panic!("POLARS_MEMORY_PREFETCH=madvise_willneed is not supported by this system"); + } + }, + Some("madvise_populate_read") => { + #[cfg(target_os = "linux")] + { + mem_prefetch_funcs::madvise_populate_read + } + #[cfg(not(target_os = "linux"))] + { + panic!( + "POLARS_MEMORY_PREFETCH=madvise_populate_read is not supported by this system" + ); + } + }, + Some(v) => panic!("invalid value for POLARS_MEMORY_PREFETCH: {}", v), + }; + + if verbose { + let func_name = match memory_prefetch_func as usize { + v if v == mem_prefetch_funcs::no_prefetch as usize => "no_prefetch", + v if v == mem_prefetch_funcs::prefetch_l2 as usize => "prefetch_l2", + v if v == mem_prefetch_funcs::madvise_sequential as usize => "madvise_sequential", + v if v == mem_prefetch_funcs::madvise_willneed as usize => "madvise_willneed", + v if v == mem_prefetch_funcs::madvise_populate_read as usize => "madvise_populate_read", + _ => unreachable!(), + }; + + eprintln!("[ParquetSource] Memory prefetch function: {}", func_name); + } + + memory_prefetch_func +} + +mod mem_prefetch_funcs { + pub use polars_utils::mem::{ + madvise_populate_read, madvise_sequential, madvise_willneed, prefetch_l2, + }; + + pub fn 
no_prefetch(_: &[u8]) {} +} diff --git a/crates/polars-stream/src/nodes/reduce.rs b/crates/polars-stream/src/nodes/reduce.rs index 4dc4d859ba62..2ce9ee2c9464 100644 --- a/crates/polars-stream/src/nodes/reduce.rs +++ b/crates/polars-stream/src/nodes/reduce.rs @@ -1,7 +1,8 @@ use std::sync::Arc; -use polars_core::schema::Schema; -use polars_expr::reduce::Reduction; +use polars_core::schema::{Schema, SchemaExt}; +use polars_expr::reduce::{Reduction, ReductionState}; +use polars_utils::itertools::Itertools; use super::compute_node_prelude::*; use crate::expression::StreamExpr; @@ -11,6 +12,7 @@ enum ReduceState { Sink { selectors: Vec, reductions: Vec>, + reduction_states: Vec>, }, Source(Option), Done, @@ -27,10 +29,12 @@ impl ReduceNode { reductions: Vec>, output_schema: Arc, ) -> Self { + let reduction_states = reductions.iter().map(|r| r.new_reducer()).collect(); Self { state: ReduceState::Sink { selectors, reductions, + reduction_states, }, output_schema, } @@ -39,6 +43,7 @@ impl ReduceNode { fn spawn_sink<'env, 's>( selectors: &'env [StreamExpr], reductions: &'env mut [Box], + reduction_states: &'env mut [Box], scope: &'s TaskScope<'s, 'env>, recv: RecvPort<'_>, state: &'s ExecutionState, @@ -48,27 +53,27 @@ impl ReduceNode { .parallel() .into_iter() .map(|mut recv| { - let mut local_reductions: Vec<_> = - reductions.iter().map(|d| d.init_dyn()).collect(); + let mut local_reducers: Vec<_> = + reductions.iter().map(|d| d.new_reducer()).collect(); scope.spawn_task(TaskPriority::High, async move { while let Ok(morsel) = recv.recv().await { - for (reduction, selector) in local_reductions.iter_mut().zip(selectors) { + for (reducer, selector) in local_reducers.iter_mut().zip(selectors) { // TODO: don't convert to physical representation here. let input = selector.evaluate(morsel.df(), state).await?; - reduction.update(&input.to_physical_repr())?; + reducer.update(&input.to_physical_repr())?; } } - PolarsResult::Ok(local_reductions) + PolarsResult::Ok(local_reducers) }) }) .collect(); join_handles.push(scope.spawn_task(TaskPriority::High, async move { for task in parallel_tasks { - let local_reductions = task.await?; - for (r1, r2) in reductions.iter_mut().zip(local_reductions) { + let local_reducers = task.await?; + for (r1, r2) in reduction_states.iter_mut().zip(local_reducers) { r1.combine(&*r2)?; } } @@ -97,7 +102,7 @@ impl ComputeNode for ReduceNode { "reduce" } - fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) { + fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> { assert!(recv.len() == 1 && send.len() == 1); // State transitions. @@ -107,19 +112,22 @@ impl ComputeNode for ReduceNode { self.state = ReduceState::Done; }, // Input is done, transition to being a source. - ReduceState::Sink { reductions, .. } if matches!(recv[0], PortState::Done) => { - // TODO! make `update_state` fallible. - let columns = reductions + ReduceState::Sink { + reduction_states, .. 
+ } if matches!(recv[0], PortState::Done) => { + let columns = reduction_states .iter_mut() .zip(self.output_schema.iter_fields()) .map(|(r, field)| { r.finalize().map(|scalar| { - scalar.into_series(&field.name).cast(&field.dtype).unwrap() + scalar + .into_series(field.name.clone()) + .cast(&field.dtype) + .unwrap() }) }) - .collect::>>() - .unwrap(); - let out = unsafe { DataFrame::new_no_checks(columns) }; + .try_collect_vec()?; + let out = DataFrame::new(columns).unwrap(); self.state = ReduceState::Source(Some(out)); }, @@ -146,6 +154,7 @@ impl ComputeNode for ReduceNode { send[0] = PortState::Done; }, } + Ok(()) } fn spawn<'env, 's>( @@ -161,10 +170,19 @@ impl ComputeNode for ReduceNode { ReduceState::Sink { selectors, reductions, + reduction_states, } => { assert!(send[0].is_none()); let recv_port = recv[0].take().unwrap(); - Self::spawn_sink(selectors, reductions, scope, recv_port, state, join_handles) + Self::spawn_sink( + selectors, + reductions, + reduction_states, + scope, + recv_port, + state, + join_handles, + ) }, ReduceState::Source(df) => { assert!(recv[0].is_none()); diff --git a/crates/polars-stream/src/nodes/select.rs b/crates/polars-stream/src/nodes/select.rs index 568351ee4f47..688580e10319 100644 --- a/crates/polars-stream/src/nodes/select.rs +++ b/crates/polars-stream/src/nodes/select.rs @@ -26,9 +26,10 @@ impl ComputeNode for SelectNode { "select" } - fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) { + fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> { assert!(recv.len() == 1 && send.len() == 1); recv.swap_with_slice(send); + Ok(()) } fn spawn<'env, 's>( @@ -59,20 +60,7 @@ impl ComputeNode for SelectNode { out._add_columns(selected, &slf.schema)?; out } else { - // Broadcast scalars. - let max_non_unit_length = selected - .iter() - .map(|s| s.len()) - .filter(|l| *l != 1) - .max() - .unwrap_or(1); - for s in &mut selected { - if s.len() != max_non_unit_length { - assert!(s.len() == 1, "got series of incompatible lengths"); - *s = s.new_from_index(0, max_non_unit_length); - } - } - unsafe { DataFrame::new_no_checks(selected) } + DataFrame::new_with_broadcast(selected)? }; let mut morsel = Morsel::new(ret, seq, source_token); diff --git a/crates/polars-stream/src/nodes/simple_projection.rs b/crates/polars-stream/src/nodes/simple_projection.rs index 1a643b642e73..95f002df2889 100644 --- a/crates/polars-stream/src/nodes/simple_projection.rs +++ b/crates/polars-stream/src/nodes/simple_projection.rs @@ -1,16 +1,17 @@ use std::sync::Arc; use polars_core::schema::Schema; +use polars_utils::pl_str::PlSmallStr; use super::compute_node_prelude::*; pub struct SimpleProjectionNode { - columns: Vec, + columns: Vec, input_schema: Arc, } impl SimpleProjectionNode { - pub fn new(columns: Vec, input_schema: Arc) -> Self { + pub fn new(columns: Vec, input_schema: Arc) -> Self { Self { columns, input_schema, @@ -23,9 +24,10 @@ impl ComputeNode for SimpleProjectionNode { "simple_projection" } - fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) { + fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> { assert!(recv.len() == 1 && send.len() == 1); recv.swap_with_slice(send); + Ok(()) } fn spawn<'env, 's>( @@ -46,7 +48,12 @@ impl ComputeNode for SimpleProjectionNode { while let Ok(morsel) = recv.recv().await { let morsel = morsel.try_map(|df| { // TODO: can this be unchecked? 
- df.select_with_schema(&slf.columns, &slf.input_schema) + let check_duplicates = true; + df._select_with_schema_impl( + slf.columns.as_slice(), + &slf.input_schema, + check_duplicates, + ) })?; if send.send(morsel).await.is_err() { diff --git a/crates/polars-stream/src/nodes/streaming_slice.rs b/crates/polars-stream/src/nodes/streaming_slice.rs index b46693bac808..950b39331588 100644 --- a/crates/polars-stream/src/nodes/streaming_slice.rs +++ b/crates/polars-stream/src/nodes/streaming_slice.rs @@ -30,13 +30,14 @@ impl ComputeNode for StreamingSliceNode { self.num_pipelines = num_pipelines; } - fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) { + fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> { if self.stream_offset >= self.start_offset + self.length || self.length == 0 { recv[0] = PortState::Done; send[0] = PortState::Done; } else { recv.swap_with_slice(send); } + Ok(()) } fn spawn<'env, 's>( diff --git a/crates/polars-stream/src/nodes/zip.rs b/crates/polars-stream/src/nodes/zip.rs index b5b860880a1b..3a9290bde59d 100644 --- a/crates/polars-stream/src/nodes/zip.rs +++ b/crates/polars-stream/src/nodes/zip.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use polars_core::functions::concat_df_horizontal; use polars_core::schema::Schema; use polars_core::series::Series; +use polars_error::polars_ensure; use super::compute_node_prelude::*; use crate::morsel::SourceToken; @@ -94,7 +95,7 @@ impl InputHead { } else { self.schema .iter() - .map(|(name, dtype)| Series::full_null(name, len, dtype)) + .map(|(name, dtype)| Series::full_null(name.clone(), len, dtype)) .collect() } } @@ -138,7 +139,7 @@ impl ComputeNode for ZipNode { "zip" } - fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) { + fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) -> PolarsResult<()> { assert!(send.len() == 1); assert!(recv.len() == self.input_heads.len()); @@ -167,9 +168,9 @@ impl ComputeNode for ZipNode { } if !self.null_extend { - assert!( + polars_ensure!( !(at_least_one_non_broadcast_done && at_least_one_non_broadcast_nonempty), - "zip received non-equal length inputs" + ShapeMismatch: "zip node received non-equal length inputs" ); } @@ -196,6 +197,7 @@ impl ComputeNode for ZipNode { for r in recv { *r = new_recv_state; } + Ok(()) } fn spawn<'env, 's>( diff --git a/crates/polars-stream/src/physical_plan/fmt.rs b/crates/polars-stream/src/physical_plan/fmt.rs new file mode 100644 index 000000000000..8a3e7a1b8ac4 --- /dev/null +++ b/crates/polars-stream/src/physical_plan/fmt.rs @@ -0,0 +1,199 @@ +use std::fmt::Write; + +use polars_plan::plans::expr_ir::ExprIR; +use polars_plan::plans::{AExpr, EscapeLabel, FileScan, PathsDisplay}; +use polars_utils::arena::Arena; +use polars_utils::itertools::Itertools; +use slotmap::{Key, SecondaryMap, SlotMap}; + +use super::{PhysNode, PhysNodeKey, PhysNodeKind}; + +fn escape_graphviz(s: &str) -> String { + s.replace('\\', "\\\\") + .replace('\n', "\\n") + .replace('"', "\\\"") +} + +fn fmt_exprs(exprs: &[ExprIR], expr_arena: &Arena) -> String { + exprs + .iter() + .map(|e| escape_graphviz(&e.display(expr_arena).to_string())) + .collect_vec() + .join("\\n") +} + +#[recursive::recursive] +fn visualize_plan_rec( + node_key: PhysNodeKey, + phys_sm: &SlotMap, + expr_arena: &Arena, + visited: &mut SecondaryMap, + out: &mut Vec, +) { + if visited.contains_key(node_key) { + return; + } + visited.insert(node_key, ()); + + use std::slice::from_ref; + let (label, inputs) = match 
&phys_sm[node_key].kind { + PhysNodeKind::InMemorySource { df } => ( + format!( + "in-memory-source\\ncols: {}", + df.get_column_names_owned().join(", ") + ), + &[][..], + ), + PhysNodeKind::Select { + input, + selectors, + extend_original, + } => { + let label = if *extend_original { + "with-columns" + } else { + "select" + }; + ( + format!("{label}\\n{}", fmt_exprs(selectors, expr_arena)), + from_ref(input), + ) + }, + PhysNodeKind::Reduce { input, exprs } => ( + format!("reduce\\n{}", fmt_exprs(exprs, expr_arena)), + from_ref(input), + ), + PhysNodeKind::StreamingSlice { + input, + offset, + length, + } => ( + format!("slice\\noffset: {offset}, length: {length}"), + from_ref(input), + ), + PhysNodeKind::Filter { input, predicate } => ( + format!("filter\\n{}", fmt_exprs(from_ref(predicate), expr_arena)), + from_ref(input), + ), + PhysNodeKind::SimpleProjection { input, columns } => ( + format!("select\\ncols: {}", columns.join(", ")), + from_ref(input), + ), + PhysNodeKind::InMemorySink { input } => ("in-memory-sink".to_string(), from_ref(input)), + PhysNodeKind::InMemoryMap { input, map: _ } => { + ("in-memory-map".to_string(), from_ref(input)) + }, + PhysNodeKind::Map { input, map: _ } => ("map".to_string(), from_ref(input)), + PhysNodeKind::Sort { + input, + by_column, + slice: _, + sort_options: _, + } => ( + format!("sort\\n{}", fmt_exprs(by_column, expr_arena)), + from_ref(input), + ), + PhysNodeKind::OrderedUnion { inputs } => ("ordered-union".to_string(), inputs.as_slice()), + PhysNodeKind::Zip { + inputs, + null_extend, + } => { + let label = if *null_extend { + "zip-null-extend" + } else { + "zip" + }; + (label.to_string(), inputs.as_slice()) + }, + PhysNodeKind::Multiplexer { input } => ("multiplexer".to_string(), from_ref(input)), + PhysNodeKind::FileScan { + paths, + file_info, + hive_parts, + output_schema: _, + scan_type, + predicate, + file_options, + } => { + let name = match scan_type { + FileScan::Parquet { .. } => "parquet-source", + FileScan::Csv { .. } => "csv-source", + FileScan::Ipc { .. } => "ipc-source", + FileScan::NDJson { .. } => "ndjson-source", + FileScan::Anonymous { .. 
} => "anonymous-source", + }; + + let mut out = name.to_string(); + let mut f = EscapeLabel(&mut out); + + { + let paths_display = PathsDisplay(paths.as_ref()); + + write!(f, "\npaths: {}", paths_display).unwrap(); + } + + { + let total_columns = + file_info.schema.len() - usize::from(file_options.row_index.is_some()); + let n_columns = file_options + .with_columns + .as_ref() + .map(|columns| columns.len()); + + if let Some(n) = n_columns { + write!(f, "\nprojection: {}/{total_columns}", n).unwrap(); + } else { + write!(f, "\nprojection: */{total_columns}").unwrap(); + } + } + + if let Some(polars_io::RowIndex { name, offset }) = &file_options.row_index { + write!(f, r#"\nrow index: name: "{}", offset: {}"#, name, offset).unwrap(); + } + + if let Some((offset, len)) = file_options.slice { + write!(f, "\nslice: offset: {}, len: {}", offset, len).unwrap(); + } + + if let Some(predicate) = predicate.as_ref() { + write!(f, "\nfilter: {}", predicate.display(expr_arena)).unwrap(); + } + + if let Some(v) = hive_parts + .as_deref() + .map(|x| x[0].get_statistics().column_stats().len()) + { + write!(f, "\nhive: {} columns", v).unwrap(); + } + + (out, &[][..]) + }, + }; + + out.push(format!( + "{} [label=\"{}\"];", + node_key.data().as_ffi(), + label + )); + for input in inputs { + visualize_plan_rec(*input, phys_sm, expr_arena, visited, out); + out.push(format!( + "{} -> {};", + input.data().as_ffi(), + node_key.data().as_ffi() + )); + } +} + +pub fn visualize_plan( + root: PhysNodeKey, + phys_sm: &SlotMap, + expr_arena: &Arena, +) -> String { + let mut visited: SecondaryMap = SecondaryMap::new(); + let mut out = Vec::with_capacity(phys_sm.len() + 2); + out.push("digraph polars {\nrankdir=\"BT\"".to_string()); + visualize_plan_rec(root, phys_sm, expr_arena, &mut visited, &mut out); + out.push("}".to_string()); + out.join("\n") +} diff --git a/crates/polars-stream/src/physical_plan/lower_expr.rs b/crates/polars-stream/src/physical_plan/lower_expr.rs new file mode 100644 index 000000000000..af0e138ec30c --- /dev/null +++ b/crates/polars-stream/src/physical_plan/lower_expr.rs @@ -0,0 +1,767 @@ +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; + +use polars_core::frame::DataFrame; +use polars_core::prelude::{Field, InitHashMaps, PlHashMap, PlHashSet}; +use polars_core::schema::{Schema, SchemaExt}; +use polars_error::PolarsResult; +use polars_expr::planner::get_expr_depth_limit; +use polars_expr::state::ExecutionState; +use polars_expr::{create_physical_expr, ExpressionConversionState}; +use polars_plan::plans::expr_ir::{ExprIR, OutputName}; +use polars_plan::plans::{AExpr, LiteralValue}; +use polars_plan::prelude::*; +use polars_utils::arena::{Arena, Node}; +use polars_utils::format_pl_smallstr; +use polars_utils::itertools::Itertools; +use polars_utils::pl_str::PlSmallStr; +use slotmap::SlotMap; + +use super::{PhysNode, PhysNodeKey, PhysNodeKind}; + +type IRNodeKey = Node; + +fn unique_column_name() -> PlSmallStr { + static COUNTER: AtomicU64 = AtomicU64::new(0); + let idx = COUNTER.fetch_add(1, Ordering::Relaxed); + format_pl_smallstr!("__POLARS_STMP_{idx}") +} + +pub(crate) struct ExprCache { + is_elementwise: PlHashMap, + is_input_independent: PlHashMap, +} + +impl ExprCache { + pub fn with_capacity(capacity: usize) -> Self { + Self { + is_elementwise: PlHashMap::with_capacity(capacity), + is_input_independent: PlHashMap::with_capacity(capacity), + } + } +} + +struct LowerExprContext<'a> { + expr_arena: &'a mut Arena, + phys_sm: &'a mut SlotMap, + cache: &'a mut ExprCache, +} + 
+#[recursive::recursive] +pub(crate) fn is_elementwise( + expr_key: IRNodeKey, + arena: &Arena, + cache: &mut ExprCache, +) -> bool { + if let Some(ret) = cache.is_elementwise.get(&expr_key) { + return *ret; + } + + let ret = match arena.get(expr_key) { + AExpr::Explode(_) => false, + AExpr::Alias(inner, _) => is_elementwise(*inner, arena, cache), + AExpr::Column(_) => true, + AExpr::Literal(lit) => !matches!(lit, LiteralValue::Series(_) | LiteralValue::Range { .. }), + AExpr::BinaryExpr { left, op: _, right } => { + is_elementwise(*left, arena, cache) && is_elementwise(*right, arena, cache) + }, + AExpr::Cast { + expr, + dtype: _, + options: _, + } => is_elementwise(*expr, arena, cache), + AExpr::Sort { .. } | AExpr::SortBy { .. } | AExpr::Gather { .. } => false, + AExpr::Filter { .. } => false, + AExpr::Agg(_) => false, + AExpr::Ternary { + predicate, + truthy, + falsy, + } => { + is_elementwise(*predicate, arena, cache) + && is_elementwise(*truthy, arena, cache) + && is_elementwise(*falsy, arena, cache) + }, + AExpr::AnonymousFunction { + input, + function: _, + output_type: _, + options, + } + | AExpr::Function { + input, + function: _, + options, + } => { + options.is_elementwise() && input.iter().all(|e| is_elementwise(e.node(), arena, cache)) + }, + + AExpr::Window { .. } => false, + AExpr::Slice { .. } => false, + AExpr::Len => false, + }; + + cache.is_elementwise.insert(expr_key, ret); + ret +} + +#[recursive::recursive] +fn is_input_independent_rec( + expr_key: IRNodeKey, + arena: &Arena, + cache: &mut PlHashMap, +) -> bool { + if let Some(ret) = cache.get(&expr_key) { + return *ret; + } + + let ret = match arena.get(expr_key) { + AExpr::Explode(inner) + | AExpr::Alias(inner, _) + | AExpr::Cast { + expr: inner, + dtype: _, + options: _, + } + | AExpr::Sort { + expr: inner, + options: _, + } => is_input_independent_rec(*inner, arena, cache), + AExpr::Column(_) => false, + AExpr::Literal(_) => true, + AExpr::BinaryExpr { left, op: _, right } => { + is_input_independent_rec(*left, arena, cache) + && is_input_independent_rec(*right, arena, cache) + }, + AExpr::Gather { + expr, + idx, + returns_scalar: _, + } => { + is_input_independent_rec(*expr, arena, cache) + && is_input_independent_rec(*idx, arena, cache) + }, + AExpr::SortBy { + expr, + by, + sort_options: _, + } => { + is_input_independent_rec(*expr, arena, cache) + && by + .iter() + .all(|expr| is_input_independent_rec(*expr, arena, cache)) + }, + AExpr::Filter { input, by } => { + is_input_independent_rec(*input, arena, cache) + && is_input_independent_rec(*by, arena, cache) + }, + AExpr::Agg(agg_expr) => match agg_expr.get_input() { + polars_plan::plans::NodeInputs::Leaf => true, + polars_plan::plans::NodeInputs::Single(expr) => { + is_input_independent_rec(expr, arena, cache) + }, + polars_plan::plans::NodeInputs::Many(exprs) => exprs + .iter() + .all(|expr| is_input_independent_rec(*expr, arena, cache)), + }, + AExpr::Ternary { + predicate, + truthy, + falsy, + } => { + is_input_independent_rec(*predicate, arena, cache) + && is_input_independent_rec(*truthy, arena, cache) + && is_input_independent_rec(*falsy, arena, cache) + }, + AExpr::AnonymousFunction { + input, + function: _, + output_type: _, + options: _, + } + | AExpr::Function { + input, + function: _, + options: _, + } => input + .iter() + .all(|expr| is_input_independent_rec(expr.node(), arena, cache)), + AExpr::Window { + function, + partition_by, + order_by, + options: _, + } => { + is_input_independent_rec(*function, arena, cache) + && partition_by + .iter() + 
.all(|expr| is_input_independent_rec(*expr, arena, cache)) + && order_by + .iter() + .all(|(expr, _options)| is_input_independent_rec(*expr, arena, cache)) + }, + AExpr::Slice { + input, + offset, + length, + } => { + is_input_independent_rec(*input, arena, cache) + && is_input_independent_rec(*offset, arena, cache) + && is_input_independent_rec(*length, arena, cache) + }, + AExpr::Len => false, + }; + + cache.insert(expr_key, ret); + ret +} + +fn is_input_independent(expr_key: IRNodeKey, ctx: &mut LowerExprContext) -> bool { + is_input_independent_rec( + expr_key, + ctx.expr_arena, + &mut ctx.cache.is_input_independent, + ) +} + +fn build_input_independent_node_with_ctx( + exprs: &[ExprIR], + ctx: &mut LowerExprContext, +) -> PolarsResult { + let expr_depth_limit = get_expr_depth_limit()?; + let mut state = ExpressionConversionState::new(false, expr_depth_limit); + let empty = DataFrame::empty(); + let execution_state = ExecutionState::new(); + let columns = exprs + .iter() + .map(|expr| { + let phys_expr = + create_physical_expr(expr, Context::Default, ctx.expr_arena, None, &mut state)?; + + phys_expr.evaluate(&empty, &execution_state) + }) + .try_collect_vec()?; + + let df = Arc::new(DataFrame::new_with_broadcast(columns)?); + Ok(ctx.phys_sm.insert(PhysNode::new( + Arc::new(df.schema()), + PhysNodeKind::InMemorySource { df }, + ))) +} + +fn simplify_input_nodes( + orig_input: PhysNodeKey, + mut input_nodes: PlHashSet, + ctx: &mut LowerExprContext, +) -> PolarsResult> { + // Flatten nested zips (ensures the original input columns only occur once). + if input_nodes.len() > 1 { + let mut flattened_input_nodes = PlHashSet::with_capacity(input_nodes.len()); + for input_node in input_nodes { + if let PhysNodeKind::Zip { + inputs, + null_extend: false, + } = &ctx.phys_sm[input_node].kind + { + flattened_input_nodes.extend(inputs); + ctx.phys_sm.remove(input_node); + } else { + flattened_input_nodes.insert(input_node); + } + } + input_nodes = flattened_input_nodes; + } + + // Merge reduce nodes that directly operate on the original input. + let mut combined_exprs = vec![]; + input_nodes = input_nodes + .into_iter() + .filter(|input_node| { + if let PhysNodeKind::Reduce { + input: inner, + exprs, + } = &ctx.phys_sm[*input_node].kind + { + if *inner == orig_input { + combined_exprs.extend(exprs.iter().cloned()); + ctx.phys_sm.remove(*input_node); + return false; + } + } + true + }) + .collect(); + if !combined_exprs.is_empty() { + let output_schema = schema_for_select(orig_input, &combined_exprs, ctx)?; + let kind = PhysNodeKind::Reduce { + input: orig_input, + exprs: combined_exprs, + }; + let reduce_node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind)); + input_nodes.insert(reduce_node_key); + } + + Ok(input_nodes) +} + +fn build_fallback_node_with_ctx( + input: PhysNodeKey, + exprs: &[ExprIR], + ctx: &mut LowerExprContext, +) -> PolarsResult { + // Pre-select only the columns that are needed for this fallback expression. + let input_schema = &ctx.phys_sm[input].output_schema; + let select_names: PlHashSet<_> = exprs + .iter() + .flat_map(|expr| polars_plan::utils::aexpr_to_leaf_names_iter(expr.node(), ctx.expr_arena)) + .collect(); + let input_node = if input_schema + .iter_names() + .any(|name| !select_names.contains(name.as_str())) + { + let select_exprs = select_names + .into_iter() + .map(|name| { + ExprIR::new( + ctx.expr_arena.add(AExpr::Column(name.clone())), + OutputName::ColumnLhs(name), + ) + }) + .collect_vec(); + build_select_node_with_ctx(input, &select_exprs, ctx)? 
+ } else { + input + }; + + let output_schema = schema_for_select(input_node, exprs, ctx)?; + let expr_depth_limit = get_expr_depth_limit()?; + let mut conv_state = ExpressionConversionState::new(false, expr_depth_limit); + let phys_exprs = exprs + .iter() + .map(|expr| { + create_physical_expr( + expr, + Context::Default, + ctx.expr_arena, + None, + &mut conv_state, + ) + }) + .try_collect_vec()?; + let map = move |df| { + let exec_state = ExecutionState::new(); + let columns = phys_exprs + .iter() + .map(|phys_expr| phys_expr.evaluate(&df, &exec_state)) + .try_collect()?; + DataFrame::new_with_broadcast(columns) + }; + let kind = PhysNodeKind::InMemoryMap { + input: input_node, + map: Arc::new(map), + }; + Ok(ctx.phys_sm.insert(PhysNode::new(output_schema, kind))) +} + +// In the recursive lowering we don't bother with named expressions at all, so +// we work directly with Nodes. +#[recursive::recursive] +fn lower_exprs_with_ctx( + input: PhysNodeKey, + exprs: &[Node], + ctx: &mut LowerExprContext, +) -> PolarsResult<(PhysNodeKey, Vec)> { + // We have to catch this case separately, in case all the input independent expressions are elementwise. + // TODO: we shouldn't always do this when recursing, e.g. pl.col.a.sum() + 1 will still hit this in the recursion. + if exprs.iter().all(|e| is_input_independent(*e, ctx)) { + let expr_irs = exprs + .iter() + .map(|e| ExprIR::new(*e, OutputName::Alias(unique_column_name()))) + .collect_vec(); + let node = build_input_independent_node_with_ctx(&expr_irs, ctx)?; + let out_exprs = expr_irs + .iter() + .map(|e| ctx.expr_arena.add(AExpr::Column(e.output_name().clone()))) + .collect(); + return Ok((node, out_exprs)); + } + + // Fallback expressions that can directly be applied to the original input. + let mut fallback_subset = Vec::new(); + + // Nodes containing the columns used for executing transformed expressions. + let mut input_nodes = PlHashSet::new(); + + // The final transformed expressions that will be selected from the zipped + // together transformed nodes. + let mut transformed_exprs = Vec::with_capacity(exprs.len()); + + for expr in exprs.iter().copied() { + if is_elementwise(expr, ctx.expr_arena, ctx.cache) { + if !is_input_independent(expr, ctx) { + input_nodes.insert(input); + } + transformed_exprs.push(expr); + continue; + } + + match ctx.expr_arena.get(expr).clone() { + AExpr::Explode(inner) => { + // While explode is streamable, it is not elementwise, so we + // have to transform it to a select node. 
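+                // (Explode changes the row count, so it cannot be zipped element-by-element
+                // with the other expressions; it is materialized through its own Select node
+                // and then referenced as a plain column below.)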
+ let (trans_input, trans_exprs) = lower_exprs_with_ctx(input, &[inner], ctx)?; + let exploded_name = unique_column_name(); + let trans_inner = ctx.expr_arena.add(AExpr::Explode(trans_exprs[0])); + let explode_expr = + ExprIR::new(trans_inner, OutputName::Alias(exploded_name.clone())); + let output_schema = schema_for_select(trans_input, &[explode_expr.clone()], ctx)?; + let node_kind = PhysNodeKind::Select { + input: trans_input, + selectors: vec![explode_expr.clone()], + extend_original: false, + }; + let node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, node_kind)); + input_nodes.insert(node_key); + transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(exploded_name))); + }, + AExpr::Alias(_, _) => unreachable!("alias found in physical plan"), + AExpr::Column(_) => unreachable!("column should always be streamable"), + AExpr::Literal(_) => { + let out_name = unique_column_name(); + let inner_expr = ExprIR::new(expr, OutputName::Alias(out_name.clone())); + input_nodes.insert(build_input_independent_node_with_ctx(&[inner_expr], ctx)?); + transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name))); + }, + AExpr::BinaryExpr { left, op, right } => { + let (trans_input, trans_exprs) = lower_exprs_with_ctx(input, &[left, right], ctx)?; + let bin_expr = AExpr::BinaryExpr { + left: trans_exprs[0], + op, + right: trans_exprs[1], + }; + input_nodes.insert(trans_input); + transformed_exprs.push(ctx.expr_arena.add(bin_expr)); + }, + AExpr::Ternary { + predicate, + truthy, + falsy, + } => { + let (trans_input, trans_exprs) = + lower_exprs_with_ctx(input, &[predicate, truthy, falsy], ctx)?; + let tern_expr = AExpr::Ternary { + predicate: trans_exprs[0], + truthy: trans_exprs[1], + falsy: trans_exprs[2], + }; + input_nodes.insert(trans_input); + transformed_exprs.push(ctx.expr_arena.add(tern_expr)); + }, + AExpr::Cast { + expr: inner, + dtype, + options, + } => { + let (trans_input, trans_exprs) = lower_exprs_with_ctx(input, &[inner], ctx)?; + input_nodes.insert(trans_input); + transformed_exprs.push(ctx.expr_arena.add(AExpr::Cast { + expr: trans_exprs[0], + dtype, + options, + })); + }, + AExpr::Sort { + expr: inner, + options, + } => { + // As we'll refer to the sorted column twice, ensure the inner + // expr is available as a column by selecting first. + let sorted_name = unique_column_name(); + let inner_expr_ir = ExprIR::new(inner, OutputName::Alias(sorted_name.clone())); + let select_node = build_select_node_with_ctx(input, &[inner_expr_ir.clone()], ctx)?; + let col_expr = ctx.expr_arena.add(AExpr::Column(sorted_name.clone())); + let kind = PhysNodeKind::Sort { + input: select_node, + by_column: vec![ExprIR::new(col_expr, OutputName::Alias(sorted_name))], + slice: None, + sort_options: (&options).into(), + }; + let output_schema = ctx.phys_sm[select_node].output_schema.clone(); + let node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind)); + input_nodes.insert(node_key); + transformed_exprs.push(col_expr); + }, + AExpr::SortBy { + expr: inner, + by, + sort_options, + } => { + // Select our inputs (if we don't do this we'll waste time sorting irrelevant columns). 
+ let sorted_name = unique_column_name(); + let by_names = by.iter().map(|_| unique_column_name()).collect_vec(); + let all_inner_expr_irs = [(&sorted_name, inner)] + .into_iter() + .chain(by_names.iter().zip(by.iter().copied())) + .map(|(name, inner)| ExprIR::new(inner, OutputName::Alias(name.clone()))) + .collect_vec(); + let select_node = build_select_node_with_ctx(input, &all_inner_expr_irs, ctx)?; + + // Sort the inputs. + let kind = PhysNodeKind::Sort { + input: select_node, + by_column: by_names + .into_iter() + .map(|name| { + ExprIR::new( + ctx.expr_arena.add(AExpr::Column(name.clone())), + OutputName::Alias(name), + ) + }) + .collect(), + slice: None, + sort_options, + }; + let output_schema = ctx.phys_sm[select_node].output_schema.clone(); + let sort_node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind)); + + // Drop the by columns. + let sorted_col_expr = ctx.expr_arena.add(AExpr::Column(sorted_name.clone())); + let sorted_col_ir = + ExprIR::new(sorted_col_expr, OutputName::Alias(sorted_name.clone())); + let post_sort_select_node = + build_select_node_with_ctx(sort_node_key, &[sorted_col_ir], ctx)?; + input_nodes.insert(post_sort_select_node); + transformed_exprs.push(sorted_col_expr); + }, + AExpr::Gather { .. } => todo!(), + AExpr::Filter { input: inner, by } => { + // Select our inputs (if we don't do this we'll waste time filtering irrelevant columns). + let out_name = unique_column_name(); + let by_name = unique_column_name(); + let inner_expr_ir = ExprIR::new(inner, OutputName::Alias(out_name.clone())); + let by_expr_ir = ExprIR::new(by, OutputName::Alias(by_name.clone())); + let select_node = + build_select_node_with_ctx(input, &[inner_expr_ir, by_expr_ir], ctx)?; + + // Add a filter node. + let predicate = ExprIR::new( + ctx.expr_arena.add(AExpr::Column(by_name.clone())), + OutputName::Alias(by_name), + ); + let kind = PhysNodeKind::Filter { + input: select_node, + predicate, + }; + let output_schema = ctx.phys_sm[select_node].output_schema.clone(); + let filter_node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind)); + input_nodes.insert(filter_node_key); + transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name))); + }, + AExpr::Agg(mut agg) => match agg { + // Change agg mutably so we can share the codepath for all of these. + IRAggExpr::Min { + input: ref mut inner, + .. + } + | IRAggExpr::Max { + input: ref mut inner, + .. + } + | IRAggExpr::Sum(ref mut inner) + | IRAggExpr::Mean(ref mut inner) => { + let (trans_input, trans_exprs) = lower_exprs_with_ctx(input, &[*inner], ctx)?; + *inner = trans_exprs[0]; + + let out_name = unique_column_name(); + let trans_agg_expr = ctx.expr_arena.add(AExpr::Agg(agg)); + let expr_ir = ExprIR::new(trans_agg_expr, OutputName::Alias(out_name.clone())); + let output_schema = schema_for_select(trans_input, &[expr_ir.clone()], ctx)?; + let kind = PhysNodeKind::Reduce { + input: trans_input, + exprs: vec![expr_ir], + }; + let reduce_node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind)); + input_nodes.insert(reduce_node_key); + transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name))); + }, + IRAggExpr::Median(_) + | IRAggExpr::NUnique(_) + | IRAggExpr::First(_) + | IRAggExpr::Last(_) + | IRAggExpr::Implode(_) + | IRAggExpr::Quantile { .. 
} + | IRAggExpr::Count(_, _) + | IRAggExpr::Std(_, _) + | IRAggExpr::Var(_, _) + | IRAggExpr::AggGroups(_) => { + let out_name = unique_column_name(); + fallback_subset.push(ExprIR::new(expr, OutputName::Alias(out_name.clone()))); + transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name))); + }, + }, + AExpr::Len => { + let out_name = unique_column_name(); + let expr_ir = ExprIR::new(expr, OutputName::Alias(out_name.clone())); + let output_schema = schema_for_select(input, &[expr_ir.clone()], ctx)?; + let kind = PhysNodeKind::Reduce { + input, + exprs: vec![expr_ir], + }; + let reduce_node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, kind)); + input_nodes.insert(reduce_node_key); + transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name))); + }, + AExpr::AnonymousFunction { .. } + | AExpr::Function { .. } + | AExpr::Slice { .. } + | AExpr::Window { .. } => { + let out_name = unique_column_name(); + fallback_subset.push(ExprIR::new(expr, OutputName::Alias(out_name.clone()))); + transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name))); + }, + } + } + + if !fallback_subset.is_empty() { + input_nodes.insert(build_fallback_node_with_ctx(input, &fallback_subset, ctx)?); + } + + // Simplify the input nodes (also ensures the original input only occurs + // once in the zip). + input_nodes = simplify_input_nodes(input, input_nodes, ctx)?; + + if input_nodes.len() == 1 { + // No need for any multiplexing/zipping, can directly execute. + return Ok((input_nodes.into_iter().next().unwrap(), transformed_exprs)); + } + + let zip_inputs = input_nodes.into_iter().collect_vec(); + let output_schema = zip_inputs + .iter() + .flat_map(|node| ctx.phys_sm[*node].output_schema.iter_fields()) + .collect(); + let zip_kind = PhysNodeKind::Zip { + inputs: zip_inputs, + null_extend: false, + }; + let zip_node = ctx + .phys_sm + .insert(PhysNode::new(Arc::new(output_schema), zip_kind)); + + Ok((zip_node, transformed_exprs)) +} + +/// Computes the schema that selecting the given expressions on the input node +/// would result in. +fn schema_for_select( + input: PhysNodeKey, + exprs: &[ExprIR], + ctx: &mut LowerExprContext, +) -> PolarsResult> { + let input_schema = &ctx.phys_sm[input].output_schema; + let output_schema: Schema = exprs + .iter() + .map(|e| { + let name = e.output_name().clone(); + let dtype = ctx.expr_arena.get(e.node()).to_dtype( + input_schema, + Context::Default, + ctx.expr_arena, + )?; + PolarsResult::Ok(Field::new(name, dtype)) + }) + .try_collect()?; + Ok(Arc::new(output_schema)) +} + +fn build_select_node_with_ctx( + input: PhysNodeKey, + exprs: &[ExprIR], + ctx: &mut LowerExprContext, +) -> PolarsResult { + if exprs.iter().all(|e| is_input_independent(e.node(), ctx)) { + return build_input_independent_node_with_ctx(exprs, ctx); + } + + // Are we only selecting simple columns, with the same name? + let all_simple_columns: Option> = exprs + .iter() + .map(|e| match ctx.expr_arena.get(e.node()) { + AExpr::Column(name) if name == e.output_name() => Some(name.clone()), + _ => None, + }) + .collect(); + + if let Some(columns) = all_simple_columns { + let input_schema = ctx.phys_sm[input].output_schema.clone(); + if input_schema.len() == columns.len() + && input_schema.iter_names().zip(&columns).all(|(l, r)| l == r) + { + // Input node already has the correct schema, just pass through. 
+ return Ok(input); + } + + let output_schema = Arc::new(input_schema.try_project(&columns)?); + let node_kind = PhysNodeKind::SimpleProjection { input, columns }; + return Ok(ctx.phys_sm.insert(PhysNode::new(output_schema, node_kind))); + } + + let node_exprs = exprs.iter().map(|e| e.node()).collect_vec(); + let (transformed_input, transformed_exprs) = lower_exprs_with_ctx(input, &node_exprs, ctx)?; + let trans_expr_irs = exprs + .iter() + .zip(transformed_exprs) + .map(|(e, te)| ExprIR::new(te, OutputName::Alias(e.output_name().clone()))) + .collect_vec(); + let output_schema = schema_for_select(transformed_input, &trans_expr_irs, ctx)?; + let node_kind = PhysNodeKind::Select { + input: transformed_input, + selectors: trans_expr_irs, + extend_original: false, + }; + Ok(ctx.phys_sm.insert(PhysNode::new(output_schema, node_kind))) +} + +/// Lowers an input node plus a set of expressions on that input node to an +/// equivalent (input node, set of expressions) pair, ensuring that the new set +/// of expressions can run on the streaming engine. +/// +/// Ensures that if the input node is transformed it has unique column names. +pub fn lower_exprs( + input: PhysNodeKey, + exprs: &[ExprIR], + expr_arena: &mut Arena<AExpr>, + phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>, + expr_cache: &mut ExprCache, +) -> PolarsResult<(PhysNodeKey, Vec<ExprIR>)> { + let mut ctx = LowerExprContext { + expr_arena, + phys_sm, + cache: expr_cache, + }; + let node_exprs = exprs.iter().map(|e| e.node()).collect_vec(); + let (transformed_input, transformed_exprs) = + lower_exprs_with_ctx(input, &node_exprs, &mut ctx)?; + let trans_expr_irs = exprs + .iter() + .zip(transformed_exprs) + .map(|(e, te)| ExprIR::new(te, OutputName::Alias(e.output_name().clone()))) + .collect_vec(); + Ok((transformed_input, trans_expr_irs)) +} + +/// Builds a selection node given an input node and the expressions to select for.
+pub fn build_select_node( + input: PhysNodeKey, + exprs: &[ExprIR], + expr_arena: &mut Arena<AExpr>, + phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>, + expr_cache: &mut ExprCache, +) -> PolarsResult<PhysNodeKey> { + let mut ctx = LowerExprContext { + expr_arena, + phys_sm, + cache: expr_cache, + }; + build_select_node_with_ctx(input, exprs, &mut ctx) +} diff --git a/crates/polars-stream/src/physical_plan/lower_ir.rs b/crates/polars-stream/src/physical_plan/lower_ir.rs index 75ae7daeb728..65044a717213 100644 --- a/crates/polars-stream/src/physical_plan/lower_ir.rs +++ b/crates/polars-stream/src/physical_plan/lower_ir.rs @@ -1,180 +1,266 @@ use std::sync::Arc; -use polars_error::PolarsResult; -use polars_expr::reduce::can_convert_into_reduction; -use polars_plan::plans::{AExpr, Context, IR}; +use polars_core::prelude::{InitHashMaps, PlHashMap, PlIndexMap}; +use polars_core::schema::Schema; +use polars_error::{polars_err, PolarsResult}; +use polars_plan::plans::expr_ir::{ExprIR, OutputName}; +use polars_plan::plans::{AExpr, IR}; use polars_plan::prelude::SinkType; use polars_utils::arena::{Arena, Node}; +use polars_utils::itertools::Itertools; use slotmap::SlotMap; -use super::{PhysNode, PhysNodeKey}; - -fn is_streamable(node: Node, arena: &Arena<AExpr>) -> bool { - polars_plan::plans::is_streamable(node, arena, Context::Default) -} +use super::{PhysNode, PhysNodeKey, PhysNodeKind}; +use crate::physical_plan::lower_expr::{is_elementwise, ExprCache}; #[recursive::recursive] pub fn lower_ir( node: Node, ir_arena: &mut Arena<IR>, - expr_arena: &Arena<AExpr>, + expr_arena: &mut Arena<AExpr>, phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>, + schema_cache: &mut PlHashMap<Node, Arc<Schema>>, + expr_cache: &mut ExprCache, ) -> PolarsResult<PhysNodeKey> { let ir_node = ir_arena.get(node); - match ir_node { + let output_schema = IR::schema_with_cache(node, ir_arena, schema_cache); + let node_kind = match ir_node { IR::SimpleProjection { input, columns } => { - let input_ir_node = ir_arena.get(*input); - let input_schema = input_ir_node.schema(ir_arena).into_owned(); - let columns = columns.iter_names().map(|s| s.to_string()).collect(); - let input = lower_ir(*input, ir_arena, expr_arena, phys_sm)?; - Ok(phys_sm.insert(PhysNode::SimpleProjection { - input, + let columns = columns.iter_names_cloned().collect::<Vec<_>>(); + let phys_input = lower_ir( + *input, + ir_arena, + expr_arena, + phys_sm, + schema_cache, + expr_cache, + )?; + PhysNodeKind::SimpleProjection { + input: phys_input, columns, - input_schema, - })) + } }, - // TODO: split partially streamable selections to avoid fallback as much as possible. - IR::Select { - input, - expr, - schema, - .. - } if expr.iter().all(|e| is_streamable(e.node(), expr_arena)) => { + IR::Select { input, expr, .. } => { let selectors = expr.clone(); - let output_schema = schema.clone(); - let input = lower_ir(*input, ir_arena, expr_arena, phys_sm)?; - Ok(phys_sm.insert(PhysNode::Select { - input, - selectors, - output_schema, - extend_original: false, - })) - }, - // TODO: split reductions and streamable selections. E.g. sum(a) + sum(b) should be split - // into Select(a + b) -> Reduce(sum(a), sum(b) - IR::Select { - input, - expr, - schema: output_schema, - ..
- } if expr - .iter() - .all(|e| can_convert_into_reduction(e.node(), expr_arena)) => - { - let exprs = expr.clone(); - let input_ir_node = ir_arena.get(*input); - let input_schema = input_ir_node.schema(ir_arena).into_owned(); - let output_schema = output_schema.clone(); - let input_node = lower_ir(*input, ir_arena, expr_arena, phys_sm)?; - Ok(phys_sm.insert(PhysNode::Reduce { - input: input_node, - exprs, - input_schema, - output_schema, - })) + let phys_input = lower_ir( + *input, + ir_arena, + expr_arena, + phys_sm, + schema_cache, + expr_cache, + )?; + return super::lower_expr::build_select_node( + phys_input, &selectors, expr_arena, phys_sm, expr_cache, + ); }, - // TODO: split partially streamable selections to avoid fallback as much as possible. - IR::HStack { - input, - exprs, - schema, - .. - } if exprs.iter().all(|e| is_streamable(e.node(), expr_arena)) => { + IR::HStack { input, exprs, .. } + if exprs + .iter() + .all(|e| is_elementwise(e.node(), expr_arena, expr_cache)) => + { + // FIXME: constant literal columns should be broadcasted with hstack. let selectors = exprs.clone(); - let output_schema = schema.clone(); - let input = lower_ir(*input, ir_arena, expr_arena, phys_sm)?; - Ok(phys_sm.insert(PhysNode::Select { - input, + let phys_input = lower_ir( + *input, + ir_arena, + expr_arena, + phys_sm, + schema_cache, + expr_cache, + )?; + PhysNodeKind::Select { + input: phys_input, selectors, - output_schema, extend_original: true, - })) + } + }, + + IR::HStack { input, exprs, .. } => { + // We already handled the all-streamable case above, so things get more complicated. + // For simplicity we just do a normal select with all the original columns prepended. + // + // FIXME: constant literal columns should be broadcasted with hstack. + let exprs = exprs.clone(); + let phys_input = lower_ir( + *input, + ir_arena, + expr_arena, + phys_sm, + schema_cache, + expr_cache, + )?; + let input_schema = &phys_sm[phys_input].output_schema; + let mut selectors = PlIndexMap::with_capacity(input_schema.len() + exprs.len()); + for name in input_schema.iter_names() { + let col_name = name.clone(); + let col_expr = expr_arena.add(AExpr::Column(col_name.clone())); + selectors.insert( + name.clone(), + ExprIR::new(col_expr, OutputName::ColumnLhs(col_name)), + ); + } + for expr in exprs { + selectors.insert(expr.output_name().clone(), expr); + } + let selectors = selectors.into_values().collect_vec(); + return super::lower_expr::build_select_node( + phys_input, &selectors, expr_arena, phys_sm, expr_cache, + ); }, IR::Slice { input, offset, len } => { if *offset >= 0 { let offset = *offset as usize; let length = *len as usize; - let input = lower_ir(*input, ir_arena, expr_arena, phys_sm)?; - Ok(phys_sm.insert(PhysNode::StreamingSlice { - input, + let phys_input = lower_ir( + *input, + ir_arena, + expr_arena, + phys_sm, + schema_cache, + expr_cache, + )?; + PhysNodeKind::StreamingSlice { + input: phys_input, offset, length, - })) + } } else { todo!() } }, - IR::Filter { input, predicate } if is_streamable(predicate.node(), expr_arena) => { + IR::Filter { input, predicate } => { let predicate = predicate.clone(); - let input = lower_ir(*input, ir_arena, expr_arena, phys_sm)?; - Ok(phys_sm.insert(PhysNode::Filter { input, predicate })) + let phys_input = lower_ir( + *input, + ir_arena, + expr_arena, + phys_sm, + schema_cache, + expr_cache, + )?; + let cols_and_predicate = output_schema + .iter_names() + .cloned() + .map(|name| { + ExprIR::new( + expr_arena.add(AExpr::Column(name.clone())), + 
OutputName::ColumnLhs(name), + ) + }) + .chain([predicate]) + .collect_vec(); + let (trans_input, mut trans_cols_and_predicate) = super::lower_expr::lower_exprs( + phys_input, + &cols_and_predicate, + expr_arena, + phys_sm, + expr_cache, + )?; + + let filter_schema = phys_sm[trans_input].output_schema.clone(); + let filter = PhysNodeKind::Filter { + input: trans_input, + predicate: trans_cols_and_predicate.last().unwrap().clone(), + }; + + let post_filter = phys_sm.insert(PhysNode::new(filter_schema, filter)); + trans_cols_and_predicate.pop(); // Remove predicate. + return super::lower_expr::build_select_node( + post_filter, + &trans_cols_and_predicate, + expr_arena, + phys_sm, + expr_cache, + ); }, IR::DataFrameScan { df, - output_schema, + output_schema: projection, filter, - schema: input_schema, + schema, .. } => { - if let Some(filter) = filter { - if !is_streamable(filter.node(), expr_arena) { - todo!() - } - } - - let mut phys_node = phys_sm.insert(PhysNode::InMemorySource { df: df.clone() }); + let mut schema = schema.clone(); // This is initially the schema of df, but can change with the projection. + let mut node_kind = PhysNodeKind::InMemorySource { df: df.clone() }; - if let Some(schema) = output_schema { - phys_node = phys_sm.insert(PhysNode::SimpleProjection { - input: phys_node, - input_schema: input_schema.clone(), - columns: schema.iter_names().map(|s| s.to_string()).collect(), - }) + // Do we need to apply a projection? + if let Some(projection_schema) = projection { + if projection_schema.len() != schema.len() + || projection_schema + .iter_names() + .zip(schema.iter_names()) + .any(|(l, r)| l != r) + { + let phys_input = phys_sm.insert(PhysNode::new(schema, node_kind)); + node_kind = PhysNodeKind::SimpleProjection { + input: phys_input, + columns: projection_schema.iter_names_cloned().collect::>(), + }; + schema = projection_schema.clone(); + } } if let Some(predicate) = filter.clone() { - phys_node = phys_sm.insert(PhysNode::Filter { - input: phys_node, + if !is_elementwise(predicate.node(), expr_arena, expr_cache) { + todo!() + } + + let phys_input = phys_sm.insert(PhysNode::new(schema, node_kind)); + node_kind = PhysNodeKind::Filter { + input: phys_input, predicate, - }) + }; } - Ok(phys_node) + node_kind }, IR::Sink { input, payload } => { if *payload == SinkType::Memory { - let schema = ir_node.schema(ir_arena).into_owned(); - let input = lower_ir(*input, ir_arena, expr_arena, phys_sm)?; - return Ok(phys_sm.insert(PhysNode::InMemorySink { input, schema })); + let phys_input = lower_ir( + *input, + ir_arena, + expr_arena, + phys_sm, + schema_cache, + expr_cache, + )?; + PhysNodeKind::InMemorySink { input: phys_input } + } else { + todo!() } - - todo!() }, IR::MapFunction { input, function } => { - let input_schema = ir_arena.get(*input).schema(ir_arena).into_owned(); let function = function.clone(); - let input = lower_ir(*input, ir_arena, expr_arena, phys_sm)?; + let phys_input = lower_ir( + *input, + ir_arena, + expr_arena, + phys_sm, + schema_cache, + expr_cache, + )?; - let phys_node = if function.is_streamable() { + if function.is_streamable() { let map = Arc::new(move |df| function.evaluate(df)); - PhysNode::Map { input, map } + PhysNodeKind::Map { + input: phys_input, + map, + } } else { let map = Arc::new(move |df| function.evaluate(df)); - PhysNode::InMemoryMap { - input, - input_schema, + PhysNodeKind::InMemoryMap { + input: phys_input, map, } - }; - - Ok(phys_sm.insert(phys_node)) + } }, IR::Sort { @@ -182,16 +268,18 @@ pub fn lower_ir( by_column, slice, 
sort_options, - } => { - let input_schema = ir_arena.get(*input).schema(ir_arena).into_owned(); - let phys_node = PhysNode::Sort { - input_schema, - by_column: by_column.clone(), - slice: *slice, - sort_options: sort_options.clone(), - input: lower_ir(*input, ir_arena, expr_arena, phys_sm)?, - }; - Ok(phys_sm.insert(phys_node)) + } => PhysNodeKind::Sort { + by_column: by_column.clone(), + slice: *slice, + sort_options: sort_options.clone(), + input: lower_ir( + *input, + ir_arena, + expr_arena, + phys_sm, + schema_cache, + expr_cache, + )?, }, IR::Union { inputs, options } => { @@ -202,9 +290,18 @@ pub fn lower_ir( let inputs = inputs .clone() // Needed to borrow ir_arena mutably. .into_iter() - .map(|input| lower_ir(input, ir_arena, expr_arena, phys_sm)) + .map(|input| { + lower_ir( + input, + ir_arena, + expr_arena, + phys_sm, + schema_cache, + expr_cache, + ) + }) .collect::>()?; - Ok(phys_sm.insert(PhysNode::OrderedUnion { inputs })) + PhysNodeKind::OrderedUnion { inputs } }, IR::HConcat { @@ -212,26 +309,57 @@ pub fn lower_ir( schema: _, options: _, } => { - let input_schemas = inputs - .iter() - .map(|input| { - let input_ir_node = ir_arena.get(*input); - input_ir_node.schema(ir_arena).into_owned() - }) - .collect(); - let inputs = inputs .clone() // Needed to borrow ir_arena mutably. .into_iter() - .map(|input| lower_ir(input, ir_arena, expr_arena, phys_sm)) + .map(|input| { + lower_ir( + input, + ir_arena, + expr_arena, + phys_sm, + schema_cache, + expr_cache, + ) + }) .collect::>()?; - Ok(phys_sm.insert(PhysNode::Zip { + PhysNodeKind::Zip { inputs, - input_schemas, null_extend: true, - })) + } + }, + + v @ IR::Scan { .. } => { + let IR::Scan { + sources, + file_info, + hive_parts, + output_schema, + scan_type, + predicate, + file_options, + } = v.clone() + else { + unreachable!(); + }; + + let paths = sources + .into_paths() + .ok_or_else(|| polars_err!(nyi = "Streaming scanning of in-memory buffers"))?; + + PhysNodeKind::FileScan { + paths, + file_info, + hive_parts, + output_schema, + scan_type, + predicate, + file_options, + } }, _ => todo!(), - } + }; + + Ok(phys_sm.insert(PhysNode::new(output_schema, node_kind))) } diff --git a/crates/polars-stream/src/physical_plan/mod.rs b/crates/polars-stream/src/physical_plan/mod.rs index ad7d42dd53ed..d22a5f968900 100644 --- a/crates/polars-stream/src/physical_plan/mod.rs +++ b/crates/polars-stream/src/physical_plan/mod.rs @@ -1,17 +1,28 @@ +use std::path::PathBuf; use std::sync::Arc; use polars_core::frame::DataFrame; -use polars_core::prelude::SortMultipleOptions; -use polars_core::schema::Schema; -use polars_plan::plans::DataFrameUdf; +use polars_core::prelude::{InitHashMaps, PlHashMap, SortMultipleOptions}; +use polars_core::schema::{Schema, SchemaRef}; +use polars_error::PolarsResult; +use polars_plan::plans::hive::HivePartitions; +use polars_plan::plans::{AExpr, DataFrameUdf, FileInfo, FileScan, IR}; use polars_plan::prelude::expr_ir::ExprIR; +mod fmt; +mod lower_expr; mod lower_ir; mod to_graph; -pub use lower_ir::lower_ir; +pub use fmt::visualize_plan; +use polars_plan::prelude::FileScanOptions; +use polars_utils::arena::{Arena, Node}; +use polars_utils::pl_str::PlSmallStr; +use slotmap::{Key, SecondaryMap, SlotMap}; pub use to_graph::physical_plan_to_graph; +use crate::physical_plan::lower_expr::ExprCache; + slotmap::new_key_type! { /// Key used for PNodes. pub struct PhysNodeKey; @@ -22,7 +33,22 @@ slotmap::new_key_type! 
{ /// A physical plan is created when the `IR` is translated to a directed /// acyclic graph of operations that can run on the streaming engine. #[derive(Clone, Debug)] -pub enum PhysNode { +pub struct PhysNode { + output_schema: Arc<Schema>, + kind: PhysNodeKind, +} + +impl PhysNode { + pub fn new(output_schema: Arc<Schema>, kind: PhysNodeKind) -> Self { + Self { + output_schema, + kind, + } + } +} + +#[derive(Clone, Debug)] +pub enum PhysNodeKind { InMemorySource { df: Arc<DataFrame>, }, @@ -31,14 +57,11 @@ pub enum PhysNode { input: PhysNodeKey, selectors: Vec<ExprIR>, extend_original: bool, - output_schema: Arc<Schema>, }, Reduce { input: PhysNodeKey, exprs: Vec<ExprIR>, - input_schema: Arc<Schema>, - output_schema: Arc<Schema>, }, StreamingSlice { @@ -54,18 +77,15 @@ pub enum PhysNode { SimpleProjection { input: PhysNodeKey, - input_schema: Arc<Schema>, - columns: Vec<String>, + columns: Vec<PlSmallStr>, }, InMemorySink { input: PhysNodeKey, - schema: Arc<Schema>, }, InMemoryMap { input: PhysNodeKey, - input_schema: Arc<Schema>, map: Arc<dyn DataFrameUdf>, }, @@ -76,7 +96,6 @@ pub enum PhysNode { Sort { input: PhysNodeKey, - input_schema: Arc<Schema>, // TODO: remove when not using fallback impl. by_column: Vec<ExprIR>, slice: Option<(i64, usize)>, sort_options: SortMultipleOptions, @@ -88,10 +107,96 @@ pub enum PhysNode { Zip { inputs: Vec<PhysNodeKey>, - input_schemas: Vec<Arc<Schema>>, /// If true shorter inputs are extended with nulls to the longest input, /// if false all inputs must be the same length, or have length 1 in /// which case they are broadcast. null_extend: bool, }, + + #[allow(unused)] + Multiplexer { + input: PhysNodeKey, + }, + + FileScan { + paths: Arc<[PathBuf]>, + file_info: FileInfo, + hive_parts: Option<Arc<Vec<HivePartitions>>>, + predicate: Option<ExprIR>, + output_schema: Option<SchemaRef>, + scan_type: FileScan, + file_options: FileScanOptions, + }, +} + +#[recursive::recursive] +fn insert_multiplexers( + node: PhysNodeKey, + phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>, + referenced: &mut SecondaryMap<PhysNodeKey, ()>, +) { + let seen_before = referenced.insert(node, ()).is_some(); + if seen_before && !matches!(phys_sm[node].kind, PhysNodeKind::Multiplexer { .. }) { + // This node is referenced at least twice. We first set the input key to + // null and then update it to avoid a double-mutable-borrow issue. + let input_schema = phys_sm[node].output_schema.clone(); + let orig_input_node = core::mem::replace( + &mut phys_sm[node], + PhysNode::new( + input_schema, + PhysNodeKind::Multiplexer { + input: PhysNodeKey::null(), + }, + ), + ); + let orig_input_key = phys_sm.insert(orig_input_node); + phys_sm[node].kind = PhysNodeKind::Multiplexer { + input: orig_input_key, + }; + } + + if !seen_before { + match &phys_sm[node].kind { + PhysNodeKind::InMemorySource { .. } | PhysNodeKind::FileScan { .. } => {}, + PhysNodeKind::Select { input, .. } + | PhysNodeKind::Reduce { input, .. } + | PhysNodeKind::StreamingSlice { input, .. } + | PhysNodeKind::Filter { input, .. } + | PhysNodeKind::SimpleProjection { input, .. } + | PhysNodeKind::InMemorySink { input } + | PhysNodeKind::InMemoryMap { input, .. } + | PhysNodeKind::Map { input, .. } + | PhysNodeKind::Sort { input, .. } + | PhysNodeKind::Multiplexer { input } => { + insert_multiplexers(*input, phys_sm, referenced); + }, + + PhysNodeKind::OrderedUnion { inputs } | PhysNodeKind::Zip { inputs, ..
} => { + for input in inputs.clone() { + insert_multiplexers(input, phys_sm, referenced); + } + }, + } + } +} + +pub fn build_physical_plan( + root: Node, + ir_arena: &mut Arena<IR>, + expr_arena: &mut Arena<AExpr>, + phys_sm: &mut SlotMap<PhysNodeKey, PhysNode>, +) -> PolarsResult<PhysNodeKey> { + let mut schema_cache = PlHashMap::with_capacity(ir_arena.len()); + let mut expr_cache = ExprCache::with_capacity(expr_arena.len()); + let phys_root = lower_ir::lower_ir( + root, + ir_arena, + expr_arena, + phys_sm, + &mut schema_cache, + &mut expr_cache, + )?; + let mut referenced = SecondaryMap::with_capacity(phys_sm.capacity()); + insert_multiplexers(phys_root, phys_sm, &mut referenced); + Ok(phys_root) } diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs index b453ab34cd7b..d0bd342b0f65 100644 --- a/crates/polars-stream/src/physical_plan/to_graph.rs +++ b/crates/polars-stream/src/physical_plan/to_graph.rs @@ -6,14 +6,16 @@ use polars_expr::planner::{create_physical_expr, get_expr_depth_limit, Expressio use polars_expr::reduce::into_reduction; use polars_expr::state::ExecutionState; use polars_mem_engine::create_physical_plan; +use polars_plan::global::_set_n_rows_for_scan; use polars_plan::plans::expr_ir::ExprIR; use polars_plan::plans::{AExpr, ArenaExprIter, Context, IR}; use polars_plan::prelude::FunctionFlags; use polars_utils::arena::{Arena, Node}; +use polars_utils::itertools::Itertools; use recursive::recursive; use slotmap::{SecondaryMap, SlotMap}; -use super::{PhysNode, PhysNodeKey}; +use super::{PhysNode, PhysNodeKey, PhysNodeKind}; use crate::expression::StreamExpr; use crate::graph::{Graph, GraphNodeKey}; use crate::nodes; @@ -45,15 +47,16 @@ fn create_stream_expr( struct GraphConversionContext<'a> { phys_sm: &'a SlotMap<PhysNodeKey, PhysNode>, - expr_arena: &'a Arena<AExpr>, + expr_arena: &'a mut Arena<AExpr>, graph: Graph, phys_to_graph: SecondaryMap<PhysNodeKey, GraphNodeKey>, expr_conversion_state: ExpressionConversionState, } pub fn physical_plan_to_graph( + root: PhysNodeKey, phys_sm: &SlotMap<PhysNodeKey, PhysNode>, - expr_arena: &Arena<AExpr>, + expr_arena: &mut Arena<AExpr>, ) -> PolarsResult<(Graph, SecondaryMap<PhysNodeKey, GraphNodeKey>)> { let expr_depth_limit = get_expr_depth_limit()?; let mut ctx = GraphConversionContext { @@ -64,9 +67,7 @@ pub fn physical_plan_to_graph( expr_conversion_state: ExpressionConversionState::new(false, expr_depth_limit), }; - for key in phys_sm.keys() { - to_graph_rec(key, &mut ctx)?; - } + to_graph_rec(root, &mut ctx)?; Ok((ctx.graph, ctx.phys_to_graph)) } @@ -81,8 +82,9 @@ fn to_graph_rec<'a>( return Ok(*graph_key); } - use PhysNode::*; - let graph_key = match &ctx.phys_sm[phys_node_key] { + use PhysNodeKind::*; + let node = &ctx.phys_sm[phys_node_key]; + let graph_key = match &node.kind { InMemorySource { df } => ctx.graph.add_node( nodes::in_memory_source::InMemorySourceNode::new(df.clone()), [], @@ -112,7 +114,6 @@ fn to_graph_rec<'a>( Select { selectors, input, - output_schema, extend_original, } => { let phys_selectors = selectors @@ -123,27 +124,21 @@ fn to_graph_rec<'a>( ctx.graph.add_node( nodes::select::SelectNode::new( phys_selectors, - output_schema.clone(), + node.output_schema.clone(), *extend_original, ), [input_key], ) }, - Reduce { - input, - exprs, - input_schema, - output_schema, - } => { + Reduce { input, exprs } => { let input_key = to_graph_rec(*input, ctx)?; + let input_schema = &ctx.phys_sm[*input].output_schema; let mut reductions = Vec::with_capacity(exprs.len()); let mut inputs = Vec::with_capacity(reductions.len()); for e in exprs { - let (red, input_node) = - into_reduction(e.node(), ctx.expr_arena,
input_schema.as_ref())? - .expect("invariant"); + let (red, input_node) = into_reduction(e.node(), ctx.expr_arena, input_schema)?; reductions.push(red); let input_phys = @@ -153,41 +148,33 @@ fn to_graph_rec<'a>( } ctx.graph.add_node( - nodes::reduce::ReduceNode::new(inputs, reductions, output_schema.clone()), + nodes::reduce::ReduceNode::new(inputs, reductions, node.output_schema.clone()), [input_key], ) }, - SimpleProjection { - input, - columns, - input_schema, - } => { + SimpleProjection { input, columns } => { + let input_schema = ctx.phys_sm[*input].output_schema.clone(); let input_key = to_graph_rec(*input, ctx)?; ctx.graph.add_node( - nodes::simple_projection::SimpleProjectionNode::new( - columns.clone(), - input_schema.clone(), - ), + nodes::simple_projection::SimpleProjectionNode::new(columns.clone(), input_schema), [input_key], ) }, - InMemorySink { input, schema } => { + InMemorySink { input } => { + let input_schema = ctx.phys_sm[*input].output_schema.clone(); let input_key = to_graph_rec(*input, ctx)?; ctx.graph.add_node( - nodes::in_memory_sink::InMemorySinkNode::new(schema.clone()), + nodes::in_memory_sink::InMemorySinkNode::new(input_schema), [input_key], ) }, - InMemoryMap { - input, - input_schema, - map, - } => { + InMemoryMap { input, map } => { + let input_schema = ctx.phys_sm[*input].output_schema.clone(); let input_key = to_graph_rec(*input, ctx)?; ctx.graph.add_node( - nodes::in_memory_map::InMemoryMapNode::new(input_schema.clone(), map.clone()), + nodes::in_memory_map::InMemoryMapNode::new(input_schema, map.clone()), [input_key], ) }, @@ -200,11 +187,11 @@ fn to_graph_rec<'a>( Sort { input, - input_schema, by_column, slice, sort_options, } => { + let input_schema = ctx.phys_sm[*input].output_schema.clone(); let lmdf = Arc::new(LateMaterializedDataFrame::default()); let mut lp_arena = Arena::default(); let df_node = lp_arena.add(lmdf.clone().as_ir_node(input_schema.clone())); @@ -223,7 +210,7 @@ fn to_graph_rec<'a>( let input_key = to_graph_rec(*input, ctx)?; ctx.graph.add_node( nodes::in_memory_map::InMemoryMapNode::new( - input_schema.clone(), + input_schema, Arc::new(move |df| { lmdf.set_materialized_dataframe(df); let mut state = ExecutionState::new(); @@ -245,18 +232,90 @@ fn to_graph_rec<'a>( Zip { inputs, - input_schemas, null_extend, } => { + let input_schemas = inputs + .iter() + .map(|i| ctx.phys_sm[*i].output_schema.clone()) + .collect_vec(); let input_keys = inputs .iter() .map(|i| to_graph_rec(*i, ctx)) - .collect::, _>>()?; + .try_collect_vec()?; ctx.graph.add_node( - nodes::zip::ZipNode::new(*null_extend, input_schemas.clone()), + nodes::zip::ZipNode::new(*null_extend, input_schemas), input_keys, ) }, + + Multiplexer { input } => { + let input_key = to_graph_rec(*input, ctx)?; + ctx.graph + .add_node(nodes::multiplexer::MultiplexerNode::new(), [input_key]) + }, + + v @ FileScan { .. 
} => { + let FileScan { + paths, + file_info, + hive_parts, + output_schema, + scan_type, + predicate, + mut file_options, + } = v.clone() + else { + unreachable!() + }; + + file_options.slice = if let Some((offset, len)) = file_options.slice { + Some((offset, _set_n_rows_for_scan(Some(len)).unwrap())) + } else { + _set_n_rows_for_scan(None).map(|x| (0, x)) + }; + + let predicate = predicate + .map(|pred| { + create_physical_expr( + &pred, + Context::Default, + ctx.expr_arena, + output_schema.as_ref(), + &mut ctx.expr_conversion_state, + ) + }) + .map_or(Ok(None), |v| v.map(Some))?; + + { + use polars_plan::prelude::FileScan; + + match scan_type { + FileScan::Parquet { + options, + cloud_options, + metadata: _, + } => { + if std::env::var("POLARS_DISABLE_PARQUET_SOURCE").as_deref() != Ok("1") { + ctx.graph.add_node( + nodes::parquet_source::ParquetSourceNode::new( + paths, + file_info, + hive_parts, + predicate, + options, + cloud_options, + file_options, + ), + [], + ) + } else { + todo!() + } + }, + _ => todo!(), + } + } + }, }; ctx.phys_to_graph.insert(phys_node_key, graph_key); diff --git a/crates/polars-stream/src/skeleton.rs b/crates/polars-stream/src/skeleton.rs index 64fcdc4d5c5e..20ca189de9e0 100644 --- a/crates/polars-stream/src/skeleton.rs +++ b/crates/polars-stream/src/skeleton.rs @@ -2,7 +2,7 @@ use polars_core::prelude::*; use polars_core::POOL; use polars_expr::planner::{create_physical_expr, get_expr_depth_limit, ExpressionConversionState}; -use polars_plan::plans::{Context, IR}; +use polars_plan::plans::{Context, IRPlan, IR}; use polars_plan::prelude::expr_ir::ExprIR; use polars_plan::prelude::AExpr; use polars_utils::arena::{Arena, Node}; @@ -15,13 +15,26 @@ fn is_streamable(node: Node, arena: &Arena) -> bool { pub fn run_query( node: Node, mut ir_arena: Arena, - expr_arena: &Arena, + expr_arena: &mut Arena, ) -> PolarsResult { + if let Ok(visual_path) = std::env::var("POLARS_VISUALIZE_IR") { + let plan = IRPlan { + lp_top: node, + lp_arena: ir_arena.clone(), + expr_arena: expr_arena.clone(), + }; + let visualization = plan.display_dot().to_string(); + std::fs::write(visual_path, visualization).unwrap(); + } let mut phys_sm = SlotMap::with_capacity_and_key(ir_arena.len()); - - let root = crate::physical_plan::lower_ir(node, &mut ir_arena, expr_arena, &mut phys_sm)?; + let root = + crate::physical_plan::build_physical_plan(node, &mut ir_arena, expr_arena, &mut phys_sm)?; + if let Ok(visual_path) = std::env::var("POLARS_VISUALIZE_PHYSICAL_PLAN") { + let visualization = crate::physical_plan::visualize_plan(root, &phys_sm, expr_arena); + std::fs::write(visual_path, visualization).unwrap(); + } let (mut graph, phys_to_graph) = - crate::physical_plan::physical_plan_to_graph(&phys_sm, expr_arena)?; + crate::physical_plan::physical_plan_to_graph(root, &phys_sm, expr_arena)?; let mut results = crate::execute::execute_graph(&mut graph)?; Ok(results.remove(phys_to_graph[root]).unwrap()) } diff --git a/crates/polars-stream/src/utils/late_materialized_df.rs b/crates/polars-stream/src/utils/late_materialized_df.rs index 2173598d5369..b18c5cea0657 100644 --- a/crates/polars-stream/src/utils/late_materialized_df.rs +++ b/crates/polars-stream/src/utils/late_materialized_df.rs @@ -4,7 +4,7 @@ use parking_lot::Mutex; use polars_core::frame::DataFrame; use polars_core::schema::Schema; use polars_error::PolarsResult; -use polars_plan::plans::{AnonymousScan, AnonymousScanArgs, FileInfo, FileScan, IR}; +use polars_plan::plans::{AnonymousScan, AnonymousScanArgs, FileInfo, FileScan, ScanSources, 
IR}; use polars_plan::prelude::{AnonymousScanOptions, FileScanOptions}; /// Used to insert a dataframe into in-memory-engine query plan after the query @@ -25,7 +25,7 @@ impl LateMaterializedDataFrame { fmt_str: "LateMaterializedDataFrame", }); IR::Scan { - paths: Arc::new(vec![]), + sources: ScanSources::Paths(Arc::default()), file_info: FileInfo::new(schema, None, (None, usize::MAX)), hive_parts: None, predicate: None, diff --git a/crates/polars-stream/src/utils/mod.rs b/crates/polars-stream/src/utils/mod.rs index 018b893ea992..4d16cd5499e3 100644 --- a/crates/polars-stream/src/utils/mod.rs +++ b/crates/polars-stream/src/utils/mod.rs @@ -1,3 +1,4 @@ pub mod in_memory_linearize; pub mod late_materialized_df; pub mod linearizer; +pub mod task_handles_ext; diff --git a/crates/polars-stream/src/utils/task_handles_ext.rs b/crates/polars-stream/src/utils/task_handles_ext.rs new file mode 100644 index 000000000000..edeca1558e80 --- /dev/null +++ b/crates/polars-stream/src/utils/task_handles_ext.rs @@ -0,0 +1,20 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// Calls [`tokio::task::JoinHandle::abort`] on the join handle when dropped. +pub struct AbortOnDropHandle(pub tokio::task::JoinHandle); + +impl Future for AbortOnDropHandle { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0).poll(cx) + } +} + +impl Drop for AbortOnDropHandle { + fn drop(&mut self) { + self.0.abort(); + } +} diff --git a/crates/polars-time/Cargo.toml b/crates/polars-time/Cargo.toml index 9fa609614c59..cbddfbf4eba8 100644 --- a/crates/polars-time/Cargo.toml +++ b/crates/polars-time/Cargo.toml @@ -23,7 +23,6 @@ now = { version = "0.1" } once_cell = { workspace = true } regex = { workspace = true } serde = { workspace = true, optional = true } -smartstring = { workspace = true } [dev-dependencies] polars-ops = { workspace = true, features = ["abs"] } @@ -34,12 +33,12 @@ dtype-datetime = ["polars-core/dtype-datetime", "temporal"] dtype-time = ["polars-core/dtype-time", "temporal"] dtype-duration = ["polars-core/dtype-duration", "temporal"] month_start = [] -month_end = [] +month_end = ["month_start"] offset_by = [] rolling_window = ["polars-core/rolling_window"] rolling_window_by = ["polars-core/rolling_window_by", "dtype-duration"] fmt = ["polars-core/fmt"] -serde = ["dep:serde", "smartstring/serde"] +serde = ["dep:serde"] temporal = ["polars-core/temporal"] timezones = ["chrono-tz", "dtype-datetime", "polars-core/timezones", "arrow/timezones", "polars-ops/timezones"] diff --git a/crates/polars-time/src/chunkedarray/date.rs b/crates/polars-time/src/chunkedarray/date.rs index 402f01c43017..8132f1ea2bba 100644 --- a/crates/polars-time/src/chunkedarray/date.rs +++ b/crates/polars-time/src/chunkedarray/date.rs @@ -73,11 +73,11 @@ pub trait DateMethods: AsDate { ca.apply_kernel_cast::(&date_to_ordinal) } - fn parse_from_str_slice(name: &str, v: &[&str], fmt: &str) -> DateChunked; + fn parse_from_str_slice(name: PlSmallStr, v: &[&str], fmt: &str) -> DateChunked; } impl DateMethods for DateChunked { - fn parse_from_str_slice(name: &str, v: &[&str], fmt: &str) -> DateChunked { + fn parse_from_str_slice(name: PlSmallStr, v: &[&str], fmt: &str) -> DateChunked { Int32Chunked::from_iter_options( name, v.iter().map(|s| { diff --git a/crates/polars-time/src/chunkedarray/datetime.rs b/crates/polars-time/src/chunkedarray/datetime.rs index de14c83c6e72..0e4adf3094b8 100644 --- a/crates/polars-time/src/chunkedarray/datetime.rs +++ 
b/crates/polars-time/src/chunkedarray/datetime.rs @@ -25,7 +25,7 @@ fn cast_and_apply< .unwrap(); func(&*arr).unwrap() }); - ChunkedArray::from_chunk_iter(ca.name(), chunks) + ChunkedArray::from_chunk_iter(ca.name().clone(), chunks) } pub trait DatetimeMethods: AsDatetime { @@ -130,7 +130,12 @@ pub trait DatetimeMethods: AsDatetime { ca.apply_kernel_cast::(&f) } - fn parse_from_str_slice(name: &str, v: &[&str], fmt: &str, tu: TimeUnit) -> DatetimeChunked { + fn parse_from_str_slice( + name: PlSmallStr, + v: &[&str], + fmt: &str, + tu: TimeUnit, + ) -> DatetimeChunked { let func = match tu { TimeUnit::Nanoseconds => datetime_to_timestamp_ns, TimeUnit::Microseconds => datetime_to_timestamp_us, @@ -175,7 +180,7 @@ mod test { // NOTE: the values are checked and correct. let dt = DatetimeChunked::from_naive_datetime( - "name", + "name".into(), datetimes.iter().copied(), TimeUnit::Nanoseconds, ); diff --git a/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs b/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs index 652629c336a4..167035a9c8bf 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/dispatch.rs @@ -32,7 +32,7 @@ where { polars_ensure!(options.min_periods <= options.window_size, InvalidOperation: "`min_periods` should be <= `window_size`"); if ca.is_empty() { - return Ok(Series::new_empty(ca.name(), ca.dtype())); + return Ok(Series::new_empty(ca.name().clone(), ca.dtype())); } let ca = ca.rechunk(); @@ -55,7 +55,7 @@ where options.fn_params, ), }; - Series::try_from((ca.name(), arr)) + Series::try_from((ca.name().clone(), arr)) } #[cfg(feature = "rolling_window_by")] @@ -80,11 +80,11 @@ where T: PolarsNumericType, { if ca.is_empty() { - return Ok(Series::new_empty(ca.name(), ca.dtype())); + return Ok(Series::new_empty(ca.name().clone(), ca.dtype())); } polars_ensure!(by.null_count() == 0 && ca.null_count() == 0, InvalidOperation: "'Expr.rolling_*_by(...)' not yet supported for series with null values, consider using 'DataFrame.rolling' or 'Expr.rolling'"); polars_ensure!(ca.len() == by.len(), InvalidOperation: "`by` column in `rolling_*_by` must be the same length as values column"); - ensure_duration_matches_data_type(options.window_size, by.dtype(), "window_size")?; + ensure_duration_matches_dtype(options.window_size, by.dtype(), "window_size")?; polars_ensure!(!options.window_size.is_zero() && !options.window_size.negative, InvalidOperation: "`window_size` must be strictly positive"); let (by, tz) = match by.dtype() { DataType::Datetime(tu, tz) => (by.cast(&DataType::Datetime(*tu, None))?, tz), @@ -141,7 +141,7 @@ where Some(sorting_indices.cont_slice().unwrap()), )? 
}; - Series::try_from((ca.name(), out)) + Series::try_from((ca.name().clone(), out)) } pub trait SeriesOpsTime: AsSeries { diff --git a/crates/polars-time/src/chunkedarray/string/infer.rs b/crates/polars-time/src/chunkedarray/string/infer.rs index 5f0f3d7daf96..f91d0ab40869 100644 --- a/crates/polars-time/src/chunkedarray/string/infer.rs +++ b/crates/polars-time/src/chunkedarray/string/infer.rs @@ -325,11 +325,11 @@ where .map(|opt_val| opt_val.and_then(|val| self.parse(val))); PrimitiveArray::from_trusted_len_iter(iter) }); - ChunkedArray::from_chunk_iter(ca.name(), chunks) + ChunkedArray::from_chunk_iter(ca.name().clone(), chunks) .into_series() .cast(&self.logical_type) .unwrap() - .with_name(ca.name()) + .with_name(ca.name().clone()) } } @@ -444,7 +444,9 @@ pub(crate) fn to_datetime( _ambiguous: &StringChunked, ) -> PolarsResult { match ca.first_non_null() { - None => Ok(Int64Chunked::full_null(ca.name(), ca.len()).into_datetime(tu, tz.cloned())), + None => { + Ok(Int64Chunked::full_null(ca.name().clone(), ca.len()).into_datetime(tu, tz.cloned())) + }, Some(idx) => { let subset = ca.slice(idx as i64, ca.len()); let pattern = subset @@ -459,7 +461,8 @@ pub(crate) fn to_datetime( // `tz` has already been validated. ca.set_time_unit_and_time_zone( tu, - tz.cloned().unwrap_or_else(|| "UTC".to_string()), + tz.cloned() + .unwrap_or_else(|| PlSmallStr::from_static("UTC")), )?; Ok(ca) })?, @@ -484,7 +487,7 @@ pub(crate) fn to_datetime( #[cfg(feature = "dtype-date")] pub(crate) fn to_date(ca: &StringChunked) -> PolarsResult { match ca.first_non_null() { - None => Ok(Int32Chunked::full_null(ca.name(), ca.len()).into_date()), + None => Ok(Int32Chunked::full_null(ca.name().clone(), ca.len()).into_date()), Some(idx) => { let subset = ca.slice(idx as i64, ca.len()); let pattern = subset diff --git a/crates/polars-time/src/chunkedarray/string/mod.rs b/crates/polars-time/src/chunkedarray/string/mod.rs index 42221b4861f1..f48faee8dd89 100644 --- a/crates/polars-time/src/chunkedarray/string/mod.rs +++ b/crates/polars-time/src/chunkedarray/string/mod.rs @@ -108,7 +108,7 @@ pub trait StringMethods: AsString { (string_ca.len() as f64).sqrt() as usize, ); let ca = unary_elementwise(string_ca, |opt_s| convert.eval(opt_s?, use_cache)); - Ok(ca.with_name(string_ca.name()).into()) + Ok(ca.with_name(string_ca.name().clone()).into()) } #[cfg(feature = "dtype-date")] @@ -143,7 +143,7 @@ pub trait StringMethods: AsString { } None }); - Ok(ca.with_name(string_ca.name()).into()) + Ok(ca.with_name(string_ca.name().clone()).into()) } #[cfg(feature = "dtype-datetime")] @@ -200,7 +200,7 @@ pub trait StringMethods: AsString { } None }) - .with_name(string_ca.name()); + .with_name(string_ca.name().clone()); match (tz_aware, tz) { #[cfg(feature = "timezones")] (false, Some(tz)) => polars_ops::prelude::replace_time_zone( @@ -210,7 +210,10 @@ pub trait StringMethods: AsString { NonExistent::Raise, ), #[cfg(feature = "timezones")] - (true, tz) => Ok(ca.into_datetime(tu, tz.cloned().or_else(|| Some("UTC".to_string())))), + (true, tz) => Ok(ca.into_datetime( + tu, + tz.cloned().or_else(|| Some(PlSmallStr::from_static("UTC"))), + )), _ => Ok(ca.into_datetime(tu, None)), } } @@ -253,7 +256,7 @@ pub trait StringMethods: AsString { unary_elementwise(string_ca, |val| convert.eval(val?, use_cache)) }; - Ok(ca.with_name(string_ca.name()).into()) + Ok(ca.with_name(string_ca.name().clone()).into()) } #[cfg(feature = "dtype-datetime")] @@ -293,10 +296,13 @@ pub trait StringMethods: AsString { ); Ok( unary_elementwise(string_ca, |opt_s| 
convert.eval(opt_s?, use_cache)) - .with_name(string_ca.name()) + .with_name(string_ca.name().clone()) .into_datetime( tu, - Some(tz.map(|x| x.to_string()).unwrap_or("UTC".to_string())), + Some( + tz.cloned() + .unwrap_or_else(|| PlSmallStr::from_static("UTC")), + ), ), ) } @@ -332,7 +338,9 @@ pub trait StringMethods: AsString { ); unary_elementwise(string_ca, |opt_s| convert.eval(opt_s?, use_cache)) }; - let dt = ca.with_name(string_ca.name()).into_datetime(tu, None); + let dt = ca + .with_name(string_ca.name().clone()) + .into_datetime(tu, None); match tz { #[cfg(feature = "timezones")] Some(tz) => polars_ops::prelude::replace_time_zone( diff --git a/crates/polars-time/src/chunkedarray/time.rs b/crates/polars-time/src/chunkedarray/time.rs index 6f3c1ab10c51..c0d267202e12 100644 --- a/crates/polars-time/src/chunkedarray/time.rs +++ b/crates/polars-time/src/chunkedarray/time.rs @@ -20,7 +20,7 @@ pub trait TimeMethods { /// The range from 1,000,000,000 to 1,999,999,999 represents the leap second. fn nanosecond(&self) -> Int32Chunked; - fn parse_from_str_slice(name: &str, v: &[&str], fmt: &str) -> TimeChunked; + fn parse_from_str_slice(name: PlSmallStr, v: &[&str], fmt: &str) -> TimeChunked; } impl TimeMethods for TimeChunked { @@ -49,7 +49,7 @@ impl TimeMethods for TimeChunked { self.apply_kernel_cast::(&time_to_nanosecond) } - fn parse_from_str_slice(name: &str, v: &[&str], fmt: &str) -> TimeChunked { + fn parse_from_str_slice(name: PlSmallStr, v: &[&str], fmt: &str) -> TimeChunked { v.iter() .map(|s| { NaiveTime::parse_from_str(s, fmt) diff --git a/crates/polars-time/src/date_range.rs b/crates/polars-time/src/date_range.rs index b44afebddb32..8f01d687fd83 100644 --- a/crates/polars-time/src/date_range.rs +++ b/crates/polars-time/src/date_range.rs @@ -3,6 +3,7 @@ use chrono::{Datelike, NaiveDateTime, NaiveTime}; use polars_core::chunked_array::temporal::time_to_time64ns; use polars_core::prelude::*; use polars_core::series::IsSorted; +use polars_utils::format_pl_smallstr; use crate::prelude::*; @@ -13,7 +14,7 @@ pub fn in_nanoseconds_window(ndt: &NaiveDateTime) -> bool { /// Create a [`DatetimeChunked`] from a given `start` and `end` date and a given `interval`. pub fn date_range( - name: &str, + name: PlSmallStr, start: NaiveDateTime, end: NaiveDateTime, interval: Duration, @@ -40,7 +41,7 @@ pub fn date_range( #[doc(hidden)] pub fn datetime_range_impl( - name: &str, + name: PlSmallStr, start: i64, end: i64, interval: Duration, @@ -54,7 +55,7 @@ pub fn datetime_range_impl( ); let mut out = match tz { #[cfg(feature = "timezones")] - Some(tz) => out.into_datetime(tu, Some(tz.to_string())), + Some(tz) => out.into_datetime(tu, Some(format_pl_smallstr!("{}", tz))), _ => out.into_datetime(tu, None), }; @@ -64,7 +65,7 @@ pub fn datetime_range_impl( /// Create a [`TimeChunked`] from a given `start` and `end` date and a given `interval`. 
pub fn time_range( - name: &str, + name: PlSmallStr, start: NaiveTime, end: NaiveTime, interval: Duration, @@ -77,7 +78,7 @@ pub fn time_range( #[doc(hidden)] pub fn time_range_impl( - name: &str, + name: PlSmallStr, start: i64, end: i64, interval: Duration, diff --git a/crates/polars-time/src/group_by/dynamic.rs b/crates/polars-time/src/group_by/dynamic.rs index 659a02ab158c..9e428794365e 100644 --- a/crates/polars-time/src/group_by/dynamic.rs +++ b/crates/polars-time/src/group_by/dynamic.rs @@ -7,10 +7,10 @@ use polars_core::utils::flatten::flatten_par; use polars_core::POOL; use polars_ops::series::SeriesMethods; use polars_utils::idx_vec::IdxVec; +use polars_utils::pl_str::PlSmallStr; use polars_utils::slice::{GetSaferUnchecked, SortedSlice}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -use smartstring::alias::String as SmartString; use crate::prelude::*; @@ -21,7 +21,7 @@ struct Wrap(pub T); #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct DynamicGroupOptions { /// Time or index column. - pub index_column: SmartString, + pub index_column: PlSmallStr, /// Start a window at this interval. pub every: Duration, /// Window duration. @@ -55,7 +55,7 @@ impl Default for DynamicGroupOptions { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct RollingGroupOptions { /// Time or index column. - pub index_column: SmartString, + pub index_column: PlSmallStr, /// Window duration. pub period: Duration, pub offset: Duration, @@ -133,8 +133,8 @@ impl Wrap<&DataFrame> { let time_type = time.dtype(); polars_ensure!(time.null_count() == 0, ComputeError: "null values in `rolling` not supported, fill nulls."); - ensure_duration_matches_data_type(options.period, time_type, "period")?; - ensure_duration_matches_data_type(options.offset, time_type, "offset")?; + ensure_duration_matches_dtype(options.period, time_type, "period")?; + ensure_duration_matches_dtype(options.offset, time_type, "offset")?; use DataType::*; let (dt, tu, tz): (Series, TimeUnit, Option) = match time_type { @@ -202,9 +202,9 @@ impl Wrap<&DataFrame> { let time_type = time.dtype(); polars_ensure!(time.null_count() == 0, ComputeError: "null values in dynamic group_by not supported, fill nulls."); - ensure_duration_matches_data_type(options.every, time_type, "every")?; - ensure_duration_matches_data_type(options.offset, time_type, "offset")?; - ensure_duration_matches_data_type(options.period, time_type, "period")?; + ensure_duration_matches_dtype(options.every, time_type, "every")?; + ensure_duration_matches_dtype(options.offset, time_type, "offset")?; + ensure_duration_matches_dtype(options.period, time_type, "period")?; use DataType::*; let (dt, tu) = match time_type { @@ -225,7 +225,7 @@ impl Wrap<&DataFrame> { )?; let out = out.cast(&Int64).unwrap().cast(&Int32).unwrap(); for k in &mut keys { - if k.name() == UP_NAME || k.name() == LB_NAME { + if k.name().as_str() == UP_NAME || k.name().as_str() == LB_NAME { *k = k.cast(&Int64).unwrap().cast(&Int32).unwrap() } } @@ -243,7 +243,7 @@ impl Wrap<&DataFrame> { )?; let out = out.cast(&Int64).unwrap(); for k in &mut keys { - if k.name() == UP_NAME || k.name() == LB_NAME { + if k.name().as_str() == UP_NAME || k.name().as_str() == LB_NAME { *k = k.cast(&Int64).unwrap() } } @@ -476,21 +476,23 @@ impl Wrap<&DataFrame> { *key = unsafe { key.agg_first(&groups) }; } - let lower = lower_bound.map(|lower| Int64Chunked::new_vec(LB_NAME, lower)); - let upper = upper_bound.map(|upper| Int64Chunked::new_vec(UP_NAME, upper)); + let lower = + 
lower_bound.map(|lower| Int64Chunked::new_vec(PlSmallStr::from_static(LB_NAME), lower)); + let upper = + upper_bound.map(|upper| Int64Chunked::new_vec(PlSmallStr::from_static(UP_NAME), upper)); if options.label == Label::Left { let mut lower = lower.clone().unwrap(); if by.is_empty() { lower.set_sorted_flag(IsSorted::Ascending) } - dt = lower.with_name(dt.name()); + dt = lower.with_name(dt.name().clone()); } else if options.label == Label::Right { let mut upper = upper.clone().unwrap(); if by.is_empty() { upper.set_sorted_flag(IsSorted::Ascending) } - dt = upper.with_name(dt.name()); + dt = upper.with_name(dt.name().clone()); } if let (true, Some(mut lower), Some(mut upper)) = (options.include_boundaries, lower, upper) @@ -671,7 +673,7 @@ mod test { TimeUnit::Milliseconds, ] { let mut date = StringChunked::new( - "dt", + "dt".into(), [ "2020-01-01 13:45:48", "2020-01-01 16:42:13", @@ -691,7 +693,7 @@ mod test { )? .into_series(); date.set_sorted_flag(IsSorted::Ascending); - let a = Series::new("a", [3, 7, 5, 9, 2, 1]); + let a = Series::new("a".into(), [3, 7, 5, 9, 2, 1]); let df = DataFrame::new(vec![date, a.clone()])?; let (_, _, groups) = df @@ -707,7 +709,7 @@ mod test { .unwrap(); let sum = unsafe { a.agg_sum(&groups) }; - let expected = Series::new("", [3, 10, 15, 24, 11, 1]); + let expected = Series::new("".into(), [3, 10, 15, 24, 11, 1]); assert_eq!(sum, expected); } @@ -717,7 +719,7 @@ mod test { #[test] fn test_rolling_group_by_aggs() -> PolarsResult<()> { let mut date = StringChunked::new( - "dt", + "dt".into(), [ "2020-01-01 13:45:48", "2020-01-01 16:42:13", @@ -738,7 +740,7 @@ mod test { .into_series(); date.set_sorted_flag(IsSorted::Ascending); - let a = Series::new("a", [3, 7, 5, 9, 2, 1]); + let a = Series::new("a".into(), [3, 7, 5, 9, 2, 1]); let df = DataFrame::new(vec![date, a.clone()])?; let (_, _, groups) = df @@ -753,10 +755,13 @@ mod test { ) .unwrap(); - let nulls = Series::new("", [Some(3), Some(7), None, Some(9), Some(2), Some(1)]); + let nulls = Series::new( + "".into(), + [Some(3), Some(7), None, Some(9), Some(2), Some(1)], + ); let min = unsafe { a.agg_min(&groups) }; - let expected = Series::new("", [3, 3, 3, 3, 2, 1]); + let expected = Series::new("".into(), [3, 3, 3, 3, 2, 1]); assert_eq!(min, expected); // Expected for nulls is equality. 
@@ -764,7 +769,7 @@ mod test { assert_eq!(min, expected); let max = unsafe { a.agg_max(&groups) }; - let expected = Series::new("", [3, 7, 7, 9, 9, 1]); + let expected = Series::new("".into(), [3, 7, 7, 9, 9, 1]); assert_eq!(max, expected); let max = unsafe { nulls.agg_max(&groups) }; @@ -772,21 +777,21 @@ mod test { let var = unsafe { a.agg_var(&groups, 1) }; let expected = Series::new( - "", + "".into(), [0.0, 8.0, 4.000000000000002, 6.666666666666667, 24.5, 0.0], ); assert!(abs(&(var - expected)?).unwrap().lt(1e-12).unwrap().all()); let var = unsafe { nulls.agg_var(&groups, 1) }; - let expected = Series::new("", [0.0, 8.0, 8.0, 9.333333333333343, 24.5, 0.0]); + let expected = Series::new("".into(), [0.0, 8.0, 8.0, 9.333333333333343, 24.5, 0.0]); assert!(abs(&(var - expected)?).unwrap().lt(1e-12).unwrap().all()); let quantile = unsafe { a.agg_quantile(&groups, 0.5, QuantileInterpolOptions::Linear) }; - let expected = Series::new("", [3.0, 5.0, 5.0, 6.0, 5.5, 1.0]); + let expected = Series::new("".into(), [3.0, 5.0, 5.0, 6.0, 5.5, 1.0]); assert_eq!(quantile, expected); let quantile = unsafe { nulls.agg_quantile(&groups, 0.5, QuantileInterpolOptions::Linear) }; - let expected = Series::new("", [3.0, 5.0, 5.0, 7.0, 5.5, 1.0]); + let expected = Series::new("".into(), [3.0, 5.0, 5.0, 7.0, 5.5, 1.0]); assert_eq!(quantile, expected); Ok(()) @@ -807,7 +812,7 @@ mod test { .and_utc() .timestamp_millis(); let range = datetime_range_impl( - "date", + "date".into(), start, stop, Duration::parse("30m"), @@ -817,7 +822,7 @@ mod test { )? .into_series(); - let groups = Series::new("groups", ["a", "a", "a", "b", "b", "a", "a"]); + let groups = Series::new("groups".into(), ["a", "a", "a", "b", "b", "a", "a"]); let df = DataFrame::new(vec![range, groups.clone()]).unwrap(); let (time_key, mut keys, groups) = df @@ -861,7 +866,7 @@ mod test { .and_utc() .timestamp_millis(); let range = datetime_range_impl( - "_upper_boundary", + "_upper_boundary".into(), start, stop, Duration::parse("1h"), @@ -886,7 +891,7 @@ mod test { .and_utc() .timestamp_millis(); let range = datetime_range_impl( - "_lower_boundary", + "_lower_boundary".into(), start, stop, Duration::parse("1h"), @@ -927,7 +932,7 @@ mod test { .and_utc() .timestamp_millis(); let range = datetime_range_impl( - "date", + "date".into(), start, stop, Duration::parse("1d"), @@ -937,7 +942,7 @@ mod test { )? 
.into_series(); - let groups = Series::new("groups", ["a", "a", "a", "b", "b", "a", "a"]); + let groups = Series::new("groups".into(), ["a", "a", "a", "b", "b", "a", "a"]); let df = DataFrame::new(vec![range, groups.clone()]).unwrap(); let (mut time_key, keys, _groups) = df @@ -955,8 +960,8 @@ mod test { }, ) .unwrap(); - time_key.rename(""); - let lower_bound = keys[1].clone().with_name(""); + time_key.rename("".into()); + let lower_bound = keys[1].clone().with_name("".into()); assert!(time_key.equals(&lower_bound)); Ok(()) } diff --git a/crates/polars-time/src/round.rs b/crates/polars-time/src/round.rs index 4bb6f2a3386f..f67c509c5dc2 100644 --- a/crates/polars-time/src/round.rs +++ b/crates/polars-time/src/round.rs @@ -5,6 +5,12 @@ use polars_core::prelude::*; use polars_utils::cache::FastFixedCache; use crate::prelude::*; +use crate::truncate::fast_truncate; + +#[inline(always)] +fn fast_round(t: i64, every: i64) -> i64 { + fast_truncate(t + every / 2, every) +} pub trait PolarsRound { fn round(&self, every: &StringChunked, tz: Option<&Tz>) -> PolarsResult @@ -35,11 +41,7 @@ impl PolarsRound for DatetimeChunked { TimeUnit::Nanoseconds => every_parsed.duration_ns(), }; return Ok(self - .apply_values(|t| { - // Round half-way values away from zero - let half_away = t.signum() * every / 2; - t + half_away - (t + half_away) % every - }) + .apply_values(|t| fast_round(t, every)) .into_datetime(self.time_unit(), time_zone.clone())); } else { let w = Window::new(every_parsed, every_parsed, offset); @@ -57,7 +59,7 @@ impl PolarsRound for DatetimeChunked { return Ok(out?.into_datetime(self.time_unit(), self.time_zone().clone())); } } else { - return Ok(Int64Chunked::full_null(self.name(), self.len()) + return Ok(Int64Chunked::full_null(self.name().clone(), self.len()) .into_datetime(self.time_unit(), self.time_zone().clone())); } } @@ -110,7 +112,7 @@ impl PolarsRound for DateChunked { ) }) } else { - Ok(Int32Chunked::full_null(self.name(), self.len())) + Ok(Int32Chunked::full_null(self.name().clone(), self.len())) } }, _ => broadcast_try_binary_elementwise(self, every, |opt_t, opt_every| { diff --git a/crates/polars-time/src/truncate.rs b/crates/polars-time/src/truncate.rs index 991ce50b547a..0548911e0fbf 100644 --- a/crates/polars-time/src/truncate.rs +++ b/crates/polars-time/src/truncate.rs @@ -12,6 +12,12 @@ pub trait PolarsTruncate { Self: Sized; } +#[inline(always)] +pub(crate) fn fast_truncate(t: i64, every: i64) -> i64 { + let remainder = t % every; + t - (remainder + every * (remainder < 0) as i64) +} + impl PolarsTruncate for DatetimeChunked { fn truncate(&self, tz: Option<&Tz>, every: &StringChunked) -> PolarsResult { let time_zone = self.time_zone(); @@ -35,10 +41,7 @@ impl PolarsTruncate for DatetimeChunked { TimeUnit::Nanoseconds => every_parsed.duration_ns(), }; return Ok(self - .apply_values(|t| { - let remainder = t % every; - t - (remainder + every * (remainder < 0) as i64) - }) + .apply_values(|t| fast_truncate(t, every)) .into_datetime(self.time_unit(), time_zone.clone())); } else { let w = Window::new(every_parsed, every_parsed, offset); @@ -56,7 +59,7 @@ impl PolarsTruncate for DatetimeChunked { return Ok(out?.into_datetime(self.time_unit(), self.time_zone().clone())); } } else { - return Ok(Int64Chunked::full_null(self.name(), self.len()) + return Ok(Int64Chunked::full_null(self.name().clone(), self.len()) .into_datetime(self.time_unit(), self.time_zone().clone())); } } @@ -107,7 +110,7 @@ impl PolarsTruncate for DateChunked { / MILLISECONDS_IN_DAY) as i32) }) } else { - 
Ok(Int32Chunked::full_null(self.name(), self.len())) + Ok(Int32Chunked::full_null(self.name().clone(), self.len())) } }, _ => broadcast_try_binary_elementwise(self, every, |opt_t, opt_every| { diff --git a/crates/polars-time/src/upsample.rs b/crates/polars-time/src/upsample.rs index 692f1a35744c..37119317ccfa 100644 --- a/crates/polars-time/src/upsample.rs +++ b/crates/polars-time/src/upsample.rs @@ -38,7 +38,7 @@ pub trait PolarsUpsample { /// day (which may not be 24 hours, depending on daylight savings). /// Similarly for "calendar week", "calendar month", "calendar quarter", /// and "calendar year". - fn upsample>( + fn upsample>( &self, by: I, time_column: &str, @@ -79,7 +79,7 @@ pub trait PolarsUpsample { /// day (which may not be 24 hours, depending on daylight savings). /// Similarly for "calendar week", "calendar month", "calendar quarter", /// and "calendar year". - fn upsample_stable>( + fn upsample_stable>( &self, by: I, time_column: &str, @@ -88,7 +88,7 @@ pub trait PolarsUpsample { } impl PolarsUpsample for DataFrame { - fn upsample>( + fn upsample>( &self, by: I, time_column: &str, @@ -96,11 +96,11 @@ impl PolarsUpsample for DataFrame { ) -> PolarsResult { let by = by.into_vec(); let time_type = self.column(time_column)?.dtype(); - ensure_duration_matches_data_type(every, time_type, "every")?; + ensure_duration_matches_dtype(every, time_type, "every")?; upsample_impl(self, by, time_column, every, false) } - fn upsample_stable>( + fn upsample_stable>( &self, by: I, time_column: &str, @@ -108,20 +108,19 @@ impl PolarsUpsample for DataFrame { ) -> PolarsResult { let by = by.into_vec(); let time_type = self.column(time_column)?.dtype(); - ensure_duration_matches_data_type(every, time_type, "every")?; + ensure_duration_matches_dtype(every, time_type, "every")?; upsample_impl(self, by, time_column, every, true) } } fn upsample_impl( source: &DataFrame, - by: Vec, + by: Vec, index_column: &str, every: Duration, stable: bool, ) -> PolarsResult { let s = source.column(index_column)?; - s.ensure_sorted_arg("upsample")?; let time_type = s.dtype(); if matches!(time_type, DataType::Date) { let mut df = source.clone(); @@ -184,6 +183,7 @@ fn upsample_single_impl( index_column: &Series, every: Duration, ) -> PolarsResult { + index_column.ensure_sorted_arg("upsample")?; let index_col_name = index_column.name(); use DataType::*; @@ -201,7 +201,7 @@ fn upsample_single_impl( _ => None, }; let range = datetime_range_impl( - index_col_name, + index_col_name.clone(), first, last, every, @@ -213,8 +213,8 @@ fn upsample_single_impl( .into_frame(); range.join( source, - &[index_col_name], - &[index_col_name], + [index_col_name.clone()], + [index_col_name.clone()], JoinArgs::new(JoinType::Left), ) }, diff --git a/crates/polars-time/src/windows/duration.rs b/crates/polars-time/src/windows/duration.rs index a459a935107b..4f300f733100 100644 --- a/crates/polars-time/src/windows/duration.rs +++ b/crates/polars-time/src/windows/duration.rs @@ -1031,12 +1031,12 @@ pub fn ensure_is_constant_duration( Ok(()) } -pub fn ensure_duration_matches_data_type( +pub fn ensure_duration_matches_dtype( duration: Duration, - data_type: &DataType, + dtype: &DataType, variable_name: &str, ) -> PolarsResult<()> { - match data_type { + match dtype { DataType::Int64 | DataType::UInt64 | DataType::Int32 | DataType::UInt32 => { polars_ensure!(duration.parsed_int || duration.is_zero(), InvalidOperation: "`{}` duration must be a parsed integer (i.e. 
use '2i', not '2d') when working with a numeric column", variable_name); @@ -1046,7 +1046,7 @@ pub fn ensure_duration_matches_data_type( InvalidOperation: "`{}` duration may not be a parsed integer (i.e. use '2d', not '2i') when working with a temporal column", variable_name); }, _ => { - polars_bail!(InvalidOperation: "unsupported data type: {} for `{}`, expected UInt64, UInt32, Int64, Int32, Datetime, Date, Duration, or Time", data_type, variable_name) + polars_bail!(InvalidOperation: "unsupported data type: {} for `{}`, expected UInt64, UInt32, Int64, Int32, Datetime, Date, Duration, or Time", dtype, variable_name) }, } Ok(()) diff --git a/crates/polars-time/src/windows/group_by.rs b/crates/polars-time/src/windows/group_by.rs index 380a92180322..9ba3a2d3dbc2 100644 --- a/crates/polars-time/src/windows/group_by.rs +++ b/crates/polars-time/src/windows/group_by.rs @@ -557,7 +557,9 @@ pub(crate) fn group_by_values_iter_lookahead_collected( } /// Different from `group_by_windows`, where define window buckets and search which values fit that -/// pre-defined bucket, this function defines every window based on the: +/// pre-defined bucket. +/// +/// This function defines every window based on the: /// - timestamp (lower bound) /// - timestamp + period (upper bound) /// where timestamps are the individual values in the array `time` diff --git a/crates/polars-utils/Cargo.toml b/crates/polars-utils/Cargo.toml index d8b2d0bc9f73..442d319b7753 100644 --- a/crates/polars-utils/Cargo.toml +++ b/crates/polars-utils/Cargo.toml @@ -14,14 +14,16 @@ polars-error = { workspace = true } ahash = { workspace = true } bytemuck = { workspace = true } bytes = { workspace = true } +compact_str = { workspace = true } hashbrown = { workspace = true } indexmap = { workspace = true } +libc = { workspace = true } memmap = { workspace = true, optional = true } num-traits = { workspace = true } once_cell = { workspace = true } raw-cpuid = { workspace = true } rayon = { workspace = true } -smartstring = { workspace = true } +serde = { workspace = true, optional = true } stacker = { workspace = true } sysinfo = { version = "0.31", default-features = false, features = ["system"], optional = true } @@ -35,3 +37,5 @@ version_check = { workspace = true } mmap = ["memmap"] bigidx = [] nightly = [] +ir_serde = ["serde"] +serde = ["dep:serde", "serde/derive"] diff --git a/crates/polars-utils/src/aliases.rs b/crates/polars-utils/src/aliases.rs index 5ecb1b0033d9..1599901677aa 100644 --- a/crates/polars-utils/src/aliases.rs +++ b/crates/polars-utils/src/aliases.rs @@ -1,9 +1,10 @@ -use ahash::RandomState; +pub type PlRandomState = ahash::RandomState; +pub type PlRandomStateQuality = ahash::RandomState; -pub type PlHashMap = hashbrown::HashMap; -pub type PlHashSet = hashbrown::HashSet; -pub type PlIndexMap = indexmap::IndexMap; -pub type PlIndexSet = indexmap::IndexSet; +pub type PlHashMap = hashbrown::HashMap; +pub type PlHashSet = hashbrown::HashSet; +pub type PlIndexMap = indexmap::IndexMap; +pub type PlIndexSet = indexmap::IndexSet; pub trait InitHashMaps { type HashMap; diff --git a/crates/polars-utils/src/arena.rs b/crates/polars-utils/src/arena.rs index 06741ff454fe..d5748725c4d1 100644 --- a/crates/polars-utils/src/arena.rs +++ b/crates/polars-utils/src/arena.rs @@ -1,5 +1,8 @@ use std::sync::atomic::{AtomicU32, Ordering}; +#[cfg(feature = "ir_serde")] +use serde::{Deserialize, Serialize}; + use crate::error::*; use crate::slice::GetSaferUnchecked; @@ -21,6 +24,7 @@ fn index_of(slice: &[T], item: &T) -> Option { 
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)] #[repr(transparent)] +#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] pub struct Node(pub usize); impl Default for Node { @@ -32,6 +36,7 @@ impl Default for Node { static ARENA_VERSION: AtomicU32 = AtomicU32::new(0); #[derive(Debug, Clone)] +#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))] pub struct Arena { version: u32, items: Vec, diff --git a/crates/polars-utils/src/binary_search.rs b/crates/polars-utils/src/binary_search.rs index b24aa3e33877..5cabb5fab654 100644 --- a/crates/polars-utils/src/binary_search.rs +++ b/crates/polars-utils/src/binary_search.rs @@ -1,3 +1,8 @@ +use std::cmp::Ordering; +use std::cmp::Ordering::{Greater, Less}; + +use crate::slice::GetSaferUnchecked; + /// Find the index of the first element of `arr` that is greater /// or equal to `val`. /// Assumes that `arr` is sorted. @@ -23,3 +28,66 @@ where Err(x) => x, } } + +// https://en.wikipedia.org/wiki/Exponential_search +// Use if you expect matches to be close by. Otherwise use binary search. +pub trait ExponentialSearch { + fn exponential_search_by(&self, f: F) -> Result + where + F: FnMut(&T) -> Ordering; + + fn partition_point_exponential
<P>
(&self, mut pred: P) -> usize + where + P: FnMut(&T) -> bool, + { + self.exponential_search_by(|x| if pred(x) { Less } else { Greater }) + .unwrap_or_else(|i| i) + } +} + +impl ExponentialSearch for &[T] { + fn exponential_search_by(&self, mut f: F) -> Result + where + F: FnMut(&T) -> Ordering, + { + if self.is_empty() { + return Err(0); + } + + let mut bound = 1; + + while bound < self.len() { + // SAFETY + // Bound is always >=0 and < len. + let cmp = f(unsafe { self.get_unchecked_release(bound) }); + + if cmp == Greater { + break; + } + bound *= 2 + } + let end_bound = std::cmp::min(self.len(), bound); + // SAFETY: + // We checked the end bound and previous bound was within slice as per the `while` condition. + let prev_bound = bound / 2; + + let slice = unsafe { self.get_unchecked_release(prev_bound..end_bound) }; + + match slice.binary_search_by(f) { + Ok(i) => Ok(i + prev_bound), + Err(i) => Err(i + prev_bound), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_partition_point() { + let v = [1, 2, 3, 3, 5, 6, 7]; + let i = v.as_slice().partition_point_exponential(|&x| x < 5); + assert_eq!(i, 4); + } +} diff --git a/crates/polars-utils/src/cache.rs b/crates/polars-utils/src/cache.rs index 3e9e0eda4d1d..4fcfd870d7e2 100644 --- a/crates/polars-utils/src/cache.rs +++ b/crates/polars-utils/src/cache.rs @@ -3,10 +3,11 @@ use std::cell::Cell; use std::hash::Hash; use std::mem::MaybeUninit; -use ahash::RandomState; use bytemuck::allocation::zeroed_vec; use bytemuck::Zeroable; +use crate::aliases::PlRandomState; + /// A cached function that use `FastFixedCache` for access speed. /// It is important that the key is relatively cheap to compute. pub struct FastCachedFunc { @@ -49,7 +50,7 @@ pub struct FastFixedCache { slots: Vec>, access_ctr: Cell, shift: u32, - hash_builder: RandomState, + random_state: PlRandomState, } impl Default for FastFixedCache { @@ -65,7 +66,7 @@ impl FastFixedCache { slots: zeroed_vec(n), access_ctr: Cell::new(1), shift: 64 - n.ilog2(), - hash_builder: RandomState::new(), + random_state: PlRandomState::default(), } } @@ -212,7 +213,7 @@ impl FastFixedCache { // An instantiation of Dietzfelbinger's multiply-shift, see 2.3 of // https://arxiv.org/pdf/1504.06804.pdf. // The magic constants are just two randomly chosen odd 64-bit numbers. - let h = self.hash_builder.hash_one(key); + let h = self.random_state.hash_one(key); let tag = h as u32; let i1 = (h.wrapping_mul(0x2e623b55bc0c9073) >> self.shift) as usize; let i2 = (h.wrapping_mul(0x921932b06a233d39) >> self.shift) as usize; diff --git a/crates/polars-utils/src/fmt.rs b/crates/polars-utils/src/fmt.rs index dc34490858c0..797c1a45d020 100644 --- a/crates/polars-utils/src/fmt.rs +++ b/crates/polars-utils/src/fmt.rs @@ -1,15 +1,3 @@ -#[macro_export] -macro_rules! format_smartstring { - ($($arg:tt)*) => {{ - use smartstring::alias::String as SmartString; - use std::fmt::Write; - - let mut string = SmartString::new(); - write!(string, $($arg)*).unwrap(); - string - }} -} - #[macro_export] macro_rules! 
format_list_container { ($e:expr, $start:tt, $end:tt) => {{ diff --git a/crates/polars-utils/src/idx_vec.rs b/crates/polars-utils/src/idx_vec.rs index c0f7098a207d..13ecfbb89448 100644 --- a/crates/polars-utils/src/idx_vec.rs +++ b/crates/polars-utils/src/idx_vec.rs @@ -99,8 +99,10 @@ impl UnitVec { } } + /// # Panics + /// Panics if `new_cap <= 1` or `new_cap < self.len` fn realloc(&mut self, new_cap: usize) { - assert!(new_cap >= self.len); + assert!(new_cap > 1 && new_cap >= self.len); unsafe { let mut me = std::mem::ManuallyDrop::new(Vec::with_capacity(new_cap)); let buffer = me.as_mut_ptr(); @@ -121,9 +123,17 @@ impl UnitVec { } pub fn with_capacity(capacity: usize) -> Self { - let mut new = Self::new(); - new.reserve(capacity); - new + if capacity <= 1 { + Self::new() + } else { + let mut me = std::mem::ManuallyDrop::new(Vec::with_capacity(capacity)); + let data = me.as_mut_ptr(); + Self { + len: 0, + capacity: NonZeroUsize::new(capacity).unwrap(), + data, + } + } } #[inline] @@ -178,13 +188,13 @@ impl Drop for UnitVec { impl Clone for UnitVec { fn clone(&self) -> Self { unsafe { - let mut me = std::mem::ManuallyDrop::new(Vec::with_capacity(self.len)); - let buffer = me.as_mut_ptr(); - std::ptr::copy(self.data_ptr(), buffer, self.len); - UnitVec { - data: buffer, - len: self.len, - capacity: NonZeroUsize::new(std::cmp::max(self.len, 1)).unwrap(), + if self.capacity.get() == 1 { + Self { ..*self } + } else { + let mut copy = Self::with_capacity(self.len); + std::ptr::copy(self.data_ptr(), copy.data_ptr_mut(), self.len); + copy.len = self.len; + copy } } } @@ -295,11 +305,57 @@ macro_rules! unitvec { ); ($elem:expr) => ( {let mut new = $crate::idx_vec::UnitVec::new(); + let v = $elem; // SAFETY: first element always fits. - unsafe { new.push_unchecked($elem) }; + unsafe { new.push_unchecked(v) }; new} ); ($($x:expr),+ $(,)?) 
=> ( vec![$($x),+].into() ); } + +mod tests { + + #[test] + #[should_panic] + fn test_unitvec_realloc_zero() { + super::UnitVec::::new().realloc(0); + } + + #[test] + #[should_panic] + fn test_unitvec_realloc_one() { + super::UnitVec::::new().realloc(1); + } + + #[test] + #[should_panic] + fn test_untivec_realloc_lt_len() { + super::UnitVec::::from(&[1, 2][..]).realloc(1) + } + + #[test] + fn test_unitvec_clone() { + { + let v = unitvec![1usize]; + assert_eq!(v, v.clone()); + } + + for n in [ + 26903816120209729usize, + 42566276440897687, + 44435161834424652, + 49390731489933083, + 51201454727649242, + 83861672190814841, + 92169290527847622, + 92476373900398436, + 95488551309275459, + 97499984126814549, + ] { + let v = unitvec![n]; + assert_eq!(v, v.clone()); + } + } +} diff --git a/crates/polars-utils/src/iter/fallible.rs b/crates/polars-utils/src/iter/fallible.rs deleted file mode 100644 index 7ba544b13e18..000000000000 --- a/crates/polars-utils/src/iter/fallible.rs +++ /dev/null @@ -1,17 +0,0 @@ -use std::error::Error; - -pub trait FallibleIterator: Iterator { - fn get_result(&mut self) -> Result<(), E>; -} - -pub trait FromFallibleIterator: Sized { - fn from_fallible_iter>(iter: F) -> Result; -} - -impl, E: Error> FromFallibleIterator for T { - fn from_fallible_iter>(mut iter: F) -> Result { - let out = T::from_iter(&mut iter); - iter.get_result()?; - Ok(out) - } -} diff --git a/crates/polars-utils/src/iter/mod.rs b/crates/polars-utils/src/iter/mod.rs deleted file mode 100644 index b3158416abb2..000000000000 --- a/crates/polars-utils/src/iter/mod.rs +++ /dev/null @@ -1,24 +0,0 @@ -mod enumerate_idx; -mod fallible; - -pub use enumerate_idx::EnumerateIdxTrait; -pub use fallible::*; - -pub trait IntoIteratorCopied: IntoIterator { - /// The type of the elements being iterated over. - type OwnedItem; - - /// Which kind of iterator are we turning this into? - type IntoIterCopied: Iterator::OwnedItem>; - - fn into_iter(self) -> ::IntoIterCopied; -} - -impl<'a, T: Copy> IntoIteratorCopied for &'a [T] { - type OwnedItem = T; - type IntoIterCopied = std::iter::Copied>; - - fn into_iter(self) -> ::IntoIterCopied { - self.iter().copied() - } -} diff --git a/crates/polars-utils/src/iter/enumerate_idx.rs b/crates/polars-utils/src/itertools/enumerate_idx.rs similarity index 84% rename from crates/polars-utils/src/iter/enumerate_idx.rs rename to crates/polars-utils/src/itertools/enumerate_idx.rs index 8b17b7ef4038..ef079fe33011 100644 --- a/crates/polars-utils/src/iter/enumerate_idx.rs +++ b/crates/polars-utils/src/itertools/enumerate_idx.rs @@ -1,6 +1,4 @@ -use num_traits::{FromPrimitive, One}; - -use crate::IdxSize; +use num_traits::{FromPrimitive, One, Zero}; /// An iterator that yields the current count and the element during iteration. 
/// @@ -16,6 +14,15 @@ pub struct EnumerateIdx { count: IdxType, } +impl EnumerateIdx { + pub fn new(iter: I) -> Self { + Self { + iter, + count: IdxType::zero(), + } + } +} + impl Iterator for EnumerateIdx where I: Iterator, @@ -92,27 +99,3 @@ where self.iter.len() } } - -pub trait EnumerateIdxTrait: Iterator { - fn enumerate_idx(self) -> EnumerateIdx - where - Self: Sized, - { - EnumerateIdx { - iter: self, - count: 0, - } - } - - fn enumerate_u32(self) -> EnumerateIdx - where - Self: Sized, - { - EnumerateIdx { - iter: self, - count: 0, - } - } -} - -impl EnumerateIdxTrait for T where T: Iterator {} diff --git a/crates/polars-utils/src/itertools/mod.rs b/crates/polars-utils/src/itertools/mod.rs new file mode 100644 index 000000000000..7c755444be1f --- /dev/null +++ b/crates/polars-utils/src/itertools/mod.rs @@ -0,0 +1,106 @@ +use std::cmp::Ordering; + +use crate::IdxSize; + +pub mod enumerate_idx; + +/// Utility extension trait of iterator methods. +pub trait Itertools: Iterator { + /// Equivalent to `.collect::>()`. + fn collect_vec(self) -> Vec + where + Self: Sized, + { + self.collect() + } + + /// Equivalent to `.collect::>()`. + fn try_collect(self) -> Result + where + Self: Sized + Iterator>, + Result: FromIterator>, + { + self.collect() + } + + /// Equivalent to `.collect::, _>>()`. + fn try_collect_vec(self) -> Result, E> + where + Self: Sized + Iterator>, + Result, E>: FromIterator>, + { + self.collect() + } + + fn enumerate_idx(self) -> enumerate_idx::EnumerateIdx + where + Self: Sized, + { + enumerate_idx::EnumerateIdx::new(self) + } + + fn enumerate_u32(self) -> enumerate_idx::EnumerateIdx + where + Self: Sized, + { + enumerate_idx::EnumerateIdx::new(self) + } + + fn all_equal(mut self) -> bool + where + Self: Sized, + Self::Item: PartialEq, + { + match self.next() { + None => true, + Some(a) => self.all(|x| a == x), + } + } + + // Stable copy of the unstable eq_by from the stdlib. + fn eq_by_(mut self, other: I, mut eq: F) -> bool + where + Self: Sized, + I: IntoIterator, + F: FnMut(Self::Item, I::Item) -> bool, + { + let mut other = other.into_iter(); + loop { + match (self.next(), other.next()) { + (None, None) => return true, + (None, Some(_)) => return false, + (Some(_), None) => return false, + (Some(l), Some(r)) => { + if eq(l, r) { + continue; + } else { + return false; + } + }, + } + } + } + + // Stable copy of the unstable partial_cmp_by from the stdlib. 
+ fn partial_cmp_by_(mut self, other: I, mut partial_cmp: F) -> Option + where + Self: Sized, + I: IntoIterator, + F: FnMut(Self::Item, I::Item) -> Option, + { + let mut other = other.into_iter(); + loop { + match (self.next(), other.next()) { + (None, None) => return Some(Ordering::Equal), + (None, Some(_)) => return Some(Ordering::Less), + (Some(_), None) => return Some(Ordering::Greater), + (Some(l), Some(r)) => match partial_cmp(l, r) { + Some(Ordering::Equal) => continue, + ord => return ord, + }, + } + } + } +} + +impl Itertools for T {} diff --git a/crates/polars-utils/src/lib.rs b/crates/polars-utils/src/lib.rs index 09360d90d2ce..68e331973800 100644 --- a/crates/polars-utils/src/lib.rs +++ b/crates/polars-utils/src/lib.rs @@ -20,6 +20,7 @@ pub mod hashing; pub mod idx_vec; pub mod mem; pub mod min_max; +pub mod pl_str; pub mod priority; pub mod slice; pub mod sort; @@ -34,7 +35,7 @@ pub use functions::*; pub mod aliases; pub mod fixedringbuffer; pub mod fmt; -pub mod iter; +pub mod itertools; pub mod macros; pub mod vec; #[cfg(target_family = "wasm")] diff --git a/crates/polars-utils/src/mem.rs b/crates/polars-utils/src/mem.rs index d4f4e3d028fd..4fee5b842f63 100644 --- a/crates/polars-utils/src/mem.rs +++ b/crates/polars-utils/src/mem.rs @@ -1,3 +1,15 @@ +use once_cell::sync::Lazy; +static PAGE_SIZE: Lazy = Lazy::new(|| { + #[cfg(target_family = "unix")] + unsafe { + libc::sysconf(libc::_SC_PAGESIZE) as usize + } + #[cfg(not(target_family = "unix"))] + { + 4096 + } +}); + /// # Safety /// This may break aliasing rules, make sure you are the only owner. #[allow(clippy::mut_from_ref)] @@ -10,7 +22,7 @@ pub unsafe fn to_mutable_slice(s: &[T]) -> &mut [T] { /// # Safety /// /// This should only be called with pointers to valid memory. -pub unsafe fn prefetch_l2(ptr: *const u8) { +unsafe fn prefetch_l2_impl(ptr: *const u8) { #[cfg(target_arch = "x86_64")] { use std::arch::x86_64::*; @@ -23,3 +35,54 @@ pub unsafe fn prefetch_l2(ptr: *const u8) { unsafe { _prefetch(ptr as *const _, _PREFETCH_READ, _PREFETCH_LOCALITY2) }; } } + +/// Attempt to prefetch the memory in the slice to the L2 cache. +pub fn prefetch_l2(slice: &[u8]) { + if slice.is_empty() { + return; + } + + // @TODO: We can play a bit more with this prefetching. Maybe introduce a maximum number of + // prefetches as to not overwhelm the processor. The linear prefetcher should pick it up + // at a certain point. + + for i in (0..slice.len()).step_by(*PAGE_SIZE) { + unsafe { prefetch_l2_impl(slice[i..].as_ptr()) }; + } + + unsafe { prefetch_l2_impl(slice[slice.len() - 1..].as_ptr()) } +} + +/// `madvise()` with `MADV_SEQUENTIAL` on unix systems. This is a no-op on non-unix systems. +pub fn madvise_sequential(#[allow(unused)] slice: &[u8]) { + #[cfg(target_family = "unix")] + madvise(slice, libc::MADV_SEQUENTIAL); +} + +/// `madvise()` with `MADV_WILLNEED` on unix systems. This is a no-op on non-unix systems. +pub fn madvise_willneed(#[allow(unused)] slice: &[u8]) { + #[cfg(target_family = "unix")] + madvise(slice, libc::MADV_WILLNEED); +} + +/// `madvise()` with `MADV_POPULATE_READ` on linux systems. This a no-op on non-linux systems. 
+pub fn madvise_populate_read(#[allow(unused)] slice: &[u8]) { + #[cfg(target_os = "linux")] + madvise(slice, libc::MADV_POPULATE_READ); +} + +#[cfg(target_family = "unix")] +fn madvise(slice: &[u8], advice: libc::c_int) { + let ptr = slice.as_ptr(); + + let align = ptr as usize % *PAGE_SIZE; + let ptr = ptr.wrapping_sub(align); + let len = slice.len() + align; + + if unsafe { libc::madvise(ptr as *mut libc::c_void, len, advice) } != 0 { + let err = std::io::Error::last_os_error(); + if let std::io::ErrorKind::InvalidInput = err.kind() { + panic!("{}", err); + } + } +} diff --git a/crates/polars-utils/src/mmap.rs b/crates/polars-utils/src/mmap.rs index 5bd8e2df12a5..9e946b3dac52 100644 --- a/crates/polars-utils/src/mmap.rs +++ b/crates/polars-utils/src/mmap.rs @@ -1,14 +1,16 @@ +use std::fs::File; use std::io; -use std::sync::Arc; pub use memmap::Mmap; mod private { + use std::fs::File; use std::ops::Deref; use std::sync::Arc; - pub use memmap::Mmap; + use polars_error::PolarsResult; + use super::MMapSemaphore; use crate::mem::prefetch_l2; /// A read-only reference to a slice of memory that can potentially be memory-mapped. @@ -34,7 +36,7 @@ mod private { #[allow(unused)] enum MemSliceInner { Bytes(bytes::Bytes), - Mmap(Arc), + Mmap(Arc), } impl Deref for MemSlice { @@ -46,6 +48,13 @@ mod private { } } + impl AsRef<[u8]> for MemSlice { + #[inline(always)] + fn as_ref(&self) -> &[u8] { + self.slice + } + } + impl Default for MemSlice { fn default() -> Self { Self::from_bytes(bytes::Bytes::new()) @@ -75,7 +84,7 @@ mod private { } #[inline] - pub fn from_mmap(mmap: Arc) -> Self { + pub fn from_mmap(mmap: Arc) -> Self { Self { slice: unsafe { std::mem::transmute::<&[u8], &'static [u8]>(mmap.as_ref().as_ref()) @@ -84,6 +93,12 @@ mod private { } } + #[inline] + pub fn from_file(file: &File) -> PolarsResult { + let mmap = MMapSemaphore::new_from_file(file)?; + Ok(Self::from_mmap(Arc::new(mmap))) + } + /// Construct a `MemSlice` that simply wraps around a `&[u8]`. #[inline] pub fn from_slice(slice: &'static [u8]) -> Self { @@ -93,19 +108,7 @@ mod private { /// Attempt to prefetch the memory belonging to to this [`MemSlice`] #[inline] pub fn prefetch(&self) { - if self.len() == 0 { - return; - } - - // @TODO: We can play a bit more with this prefetching. Maybe introduce a maximum number of - // prefetches as to not overwhelm the processor. The linear prefetcher should pick it up - // at a certain point. - - const PAGE_SIZE: usize = 4096; - for i in 0..self.len() / PAGE_SIZE { - unsafe { prefetch_l2(self[i * PAGE_SIZE..].as_ptr()) }; - } - unsafe { prefetch_l2(self[self.len() - 1..].as_ptr()) } + prefetch_l2(self.as_ref()); } /// # Panics @@ -120,6 +123,8 @@ mod private { } } +use memmap::MmapOptions; +use polars_error::{polars_bail, PolarsResult}; pub use private::MemSlice; /// A cursor over a [`MemSlice`]. @@ -161,11 +166,6 @@ impl MemReader { Self::new(MemSlice::from_bytes(bytes)) } - #[inline(always)] - pub fn from_mmap(mmap: Arc) -> Self { - Self::new(MemSlice::from_mmap(mmap)) - } - // Construct a `MemSlice` that simply wraps around a `&[u8]`. The caller must ensure the /// slice outlives the returned `MemSlice`. #[inline] @@ -236,8 +236,91 @@ impl io::Seek for MemReader { } } -mod tests { +// Keep track of memory mapped files so we don't write to them while reading +// Use a btree as it uses less memory than a hashmap and this thing never shrinks. +// Write handle in Windows is exclusive, so this is only necessary in Unix. 
+#[cfg(target_family = "unix")] +static MEMORY_MAPPED_FILES: once_cell::sync::Lazy< + std::sync::Mutex>, +> = once_cell::sync::Lazy::new(|| std::sync::Mutex::new(Default::default())); + +#[derive(Debug)] +pub struct MMapSemaphore { + #[cfg(target_family = "unix")] + key: (u64, u64), + mmap: Mmap, +} + +impl MMapSemaphore { + pub fn new_from_file_with_options( + file: &File, + options: MmapOptions, + ) -> PolarsResult { + let mmap = unsafe { options.map(file) }?; + #[cfg(target_family = "unix")] + { + use std::os::unix::fs::MetadataExt; + let metadata = file.metadata()?; + + let mut guard = MEMORY_MAPPED_FILES.lock().unwrap(); + let key = (metadata.dev(), metadata.ino()); + match guard.entry(key) { + std::collections::btree_map::Entry::Occupied(mut e) => *e.get_mut() += 1, + std::collections::btree_map::Entry::Vacant(e) => _ = e.insert(1), + } + Ok(Self { key, mmap }) + } + + #[cfg(not(target_family = "unix"))] + Ok(Self { mmap }) + } + + pub fn new_from_file(file: &File) -> PolarsResult { + Self::new_from_file_with_options(file, MmapOptions::default()) + } + + pub fn as_ptr(&self) -> *const u8 { + self.mmap.as_ptr() + } +} + +impl AsRef<[u8]> for MMapSemaphore { + #[inline] + fn as_ref(&self) -> &[u8] { + self.mmap.as_ref() + } +} + +#[cfg(target_family = "unix")] +impl Drop for MMapSemaphore { + fn drop(&mut self) { + let mut guard = MEMORY_MAPPED_FILES.lock().unwrap(); + if let std::collections::btree_map::Entry::Occupied(mut e) = guard.entry(self.key) { + let v = e.get_mut(); + *v -= 1; + + if *v == 0 { + e.remove_entry(); + } + } + } +} + +pub fn ensure_not_mapped(#[allow(unused)] file: &File) -> PolarsResult<()> { + #[cfg(target_family = "unix")] + { + use std::os::unix::fs::MetadataExt; + let guard = MEMORY_MAPPED_FILES.lock().unwrap(); + let metadata = file.metadata()?; + if guard.contains_key(&(metadata.dev(), metadata.ino())) { + polars_bail!(ComputeError: "cannot write to file: already memory mapped"); + } + } + Ok(()) +} + +mod tests { #[test] fn test_mem_slice_zero_copy() { use std::sync::Arc; @@ -276,9 +359,11 @@ mod tests { } { + use crate::mmap::MMapSemaphore; + let path = "../../examples/datasets/foods1.csv"; let file = std::fs::File::open(path).unwrap(); - let mmap = unsafe { memmap::Mmap::map(&file) }.unwrap(); + let mmap = MMapSemaphore::new_from_file(&file).unwrap(); let ptr = mmap.as_ptr(); let mem_slice = MemSlice::from_mmap(Arc::new(mmap)); diff --git a/crates/polars-utils/src/partitioned.rs b/crates/polars-utils/src/partitioned.rs index c23af9e95d76..f9f9eb563ae8 100644 --- a/crates/polars-utils/src/partitioned.rs +++ b/crates/polars-utils/src/partitioned.rs @@ -1,9 +1,10 @@ use hashbrown::hash_map::{HashMap, RawEntryBuilder, RawEntryBuilderMut}; +use crate::aliases::PlRandomState; use crate::hashing::hash_to_partition; use crate::slice::GetSaferUnchecked; -pub struct PartitionedHashMap { +pub struct PartitionedHashMap { inner: Vec>, } diff --git a/crates/polars-utils/src/pl_str.rs b/crates/polars-utils/src/pl_str.rs new file mode 100644 index 000000000000..72beb122b233 --- /dev/null +++ b/crates/polars-utils/src/pl_str.rs @@ -0,0 +1,252 @@ +#[macro_export] +macro_rules! format_pl_smallstr { + ($($arg:tt)*) => {{ + use std::fmt::Write; + + let mut string = PlSmallStr::EMPTY; + write!(string, $($arg)*).unwrap(); + string + }} +} + +type Inner = compact_str::CompactString; + +/// String type that inlines small strings. 
+#[derive(Clone, Eq, Hash, PartialOrd, Ord)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct PlSmallStr(Inner); + +impl PlSmallStr { + pub const EMPTY: Self = Self::from_static(""); + pub const EMPTY_REF: &'static Self = &Self::from_static(""); + + #[inline(always)] + pub const fn from_static(s: &'static str) -> Self { + Self(Inner::const_new(s)) + } + + #[inline(always)] + #[allow(clippy::should_implement_trait)] + pub fn from_str(s: &str) -> Self { + Self(Inner::from(s)) + } + + #[inline(always)] + pub fn from_string(s: String) -> Self { + Self(Inner::from(s)) + } + + #[inline(always)] + pub fn as_str(&self) -> &str { + self.0.as_str() + } + + #[inline(always)] + pub fn into_string(self) -> String { + self.0.into_string() + } +} + +impl Default for PlSmallStr { + #[inline(always)] + fn default() -> Self { + Self::EMPTY + } +} + +/// AsRef, Deref and Borrow impls to &str + +impl AsRef for PlSmallStr { + #[inline(always)] + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl core::ops::Deref for PlSmallStr { + type Target = str; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + self.as_str() + } +} + +impl core::borrow::Borrow for PlSmallStr { + #[inline(always)] + fn borrow(&self) -> &str { + self.as_str() + } +} + +/// AsRef impls for other types + +impl AsRef for PlSmallStr { + #[inline(always)] + fn as_ref(&self) -> &std::path::Path { + self.as_str().as_ref() + } +} + +impl AsRef<[u8]> for PlSmallStr { + #[inline(always)] + fn as_ref(&self) -> &[u8] { + self.as_str().as_bytes() + } +} + +impl AsRef for PlSmallStr { + #[inline(always)] + fn as_ref(&self) -> &std::ffi::OsStr { + self.as_str().as_ref() + } +} + +/// From impls + +impl From<&str> for PlSmallStr { + #[inline(always)] + fn from(value: &str) -> Self { + Self::from_str(value) + } +} + +impl From for PlSmallStr { + #[inline(always)] + fn from(value: String) -> Self { + Self::from_string(value) + } +} + +impl From<&String> for PlSmallStr { + #[inline(always)] + fn from(value: &String) -> Self { + Self::from_str(value.as_str()) + } +} + +impl From for PlSmallStr { + #[inline(always)] + fn from(value: Inner) -> Self { + Self(value) + } +} + +/// FromIterator impls + +impl FromIterator for PlSmallStr { + #[inline(always)] + fn from_iter>(iter: T) -> Self { + Self(Inner::from_iter(iter.into_iter().map(|x| x.0))) + } +} + +impl<'a> FromIterator<&'a PlSmallStr> for PlSmallStr { + #[inline(always)] + fn from_iter>(iter: T) -> Self { + Self(Inner::from_iter(iter.into_iter().map(|x| x.as_str()))) + } +} + +impl FromIterator for PlSmallStr { + #[inline(always)] + fn from_iter>(iter: I) -> PlSmallStr { + Self(Inner::from_iter(iter)) + } +} + +impl<'a> FromIterator<&'a char> for PlSmallStr { + #[inline(always)] + fn from_iter>(iter: I) -> PlSmallStr { + Self(Inner::from_iter(iter)) + } +} + +impl<'a> FromIterator<&'a str> for PlSmallStr { + #[inline(always)] + fn from_iter>(iter: I) -> PlSmallStr { + Self(Inner::from_iter(iter)) + } +} + +impl FromIterator for PlSmallStr { + #[inline(always)] + fn from_iter>(iter: I) -> PlSmallStr { + Self(Inner::from_iter(iter)) + } +} + +impl FromIterator> for PlSmallStr { + #[inline(always)] + fn from_iter>>(iter: I) -> PlSmallStr { + Self(Inner::from_iter(iter)) + } +} + +impl<'a> FromIterator> for PlSmallStr { + #[inline(always)] + fn from_iter>>(iter: I) -> PlSmallStr { + Self(Inner::from_iter(iter)) + } +} + +/// PartialEq impls + +impl PartialEq for PlSmallStr +where + T: AsRef + ?Sized, +{ + #[inline(always)] + fn eq(&self, other: &T) -> 
bool { + self.as_str() == other.as_ref() + } +} + +impl PartialEq for &str { + #[inline(always)] + fn eq(&self, other: &PlSmallStr) -> bool { + *self == other.as_str() + } +} + +impl PartialEq for String { + #[inline(always)] + fn eq(&self, other: &PlSmallStr) -> bool { + self.as_str() == other.as_str() + } +} + +/// Write + +impl core::fmt::Write for PlSmallStr { + #[inline(always)] + fn write_char(&mut self, c: char) -> std::fmt::Result { + self.0.write_char(c) + } + + #[inline(always)] + fn write_fmt(&mut self, args: std::fmt::Arguments<'_>) -> std::fmt::Result { + self.0.write_fmt(args) + } + + #[inline(always)] + fn write_str(&mut self, s: &str) -> std::fmt::Result { + self.0.write_str(s) + } +} + +/// Debug, Display + +impl core::fmt::Debug for PlSmallStr { + #[inline(always)] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.as_str().fmt(f) + } +} + +impl core::fmt::Display for PlSmallStr { + #[inline(always)] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.as_str().fmt(f) + } +} diff --git a/crates/polars-utils/src/sync.rs b/crates/polars-utils/src/sync.rs index 895b1290be31..31151fb32518 100644 --- a/crates/polars-utils/src/sync.rs +++ b/crates/polars-utils/src/sync.rs @@ -1,6 +1,6 @@ /// Utility that allows use to send pointers to another thread. /// This is better than going through `usize` as MIRI can follow these. -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] #[repr(transparent)] pub struct SyncPtr(*mut T); diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 97bb99b279cf..a27907484369 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -25,7 +25,7 @@ polars-utils = { workspace = true } [dev-dependencies] ahash = { workspace = true } -apache-avro = { version = "0.16", features = ["snappy"] } +apache-avro = { version = "0.17", features = ["snappy"] } arrow = { workspace = true, features = ["arrow_rs"] } arrow-buffer = { workspace = true } avro-schema = { workspace = true, features = ["async"] } @@ -65,13 +65,14 @@ default = [ ] ndarray = ["polars-core/ndarray"] # serde support for dataframes and series -serde = ["polars-core/serde"] +serde = ["polars-core/serde", "polars-utils/serde"] serde-lazy = [ "polars-core/serde-lazy", "polars-lazy?/serde", "polars-time?/serde", "polars-io?/serde", "polars-ops?/serde", + "polars-utils/serde", ] parquet = ["polars-io", "polars-lazy?/parquet", "polars-io/parquet", "polars-sql?/parquet"] async = ["polars-lazy?/async"] @@ -190,7 +191,7 @@ moment = ["polars-ops/moment", "polars-lazy?/moment"] partition_by = ["polars-core/partition_by"] pct_change = ["polars-ops/pct_change", "polars-lazy?/pct_change"] peaks = ["polars-lazy/peaks"] -pivot = ["polars-lazy?/pivot"] +pivot = ["polars-lazy?/pivot", "polars-ops/pivot", "dtype-struct", "rows"] product = ["polars-core/product"] propagate_nans = ["polars-lazy?/propagate_nans"] range = ["polars-lazy?/range"] @@ -228,6 +229,7 @@ zip_with = ["polars-core/zip_with"] bigidx = ["polars-core/bigidx", "polars-lazy?/bigidx", "polars-ops/big_idx"] polars_cloud = ["polars-lazy?/polars_cloud"] +ir_serde = ["polars-plan/ir_serde"] test = [ "lazy", diff --git a/crates/polars/src/docs/eager.rs b/crates/polars/src/docs/eager.rs index 8faf05c0a96c..95c759f836e7 100644 --- a/crates/polars/src/docs/eager.rs +++ b/crates/polars/src/docs/eager.rs @@ -48,10 +48,10 @@ //! let ca: UInt32Chunked = (0..10).map(Some).collect(); //! //! // from slices -//! let ca = UInt32Chunked::new("foo", &[1, 2, 3]); +//! 
let ca = UInt32Chunked::new("foo".into(), &[1, 2, 3]); //! //! // use builders -//! let mut builder = PrimitiveChunkedBuilder::::new("foo", 10); +//! let mut builder = PrimitiveChunkedBuilder::::new("foo".into(), 10); //! for value in 0..10 { //! builder.append_value(value); //! } @@ -67,10 +67,10 @@ //! let s: Series = (0..10).map(Some).collect(); //! //! // from slices -//! let s = Series::new("foo", &[1, 2, 3]); +//! let s = Series::new("foo".into(), &[1, 2, 3]); //! //! // from a chunked-array -//! let ca = UInt32Chunked::new("foo", &[Some(1), None, Some(3)]); +//! let ca = UInt32Chunked::new("foo".into(), &[Some(1), None, Some(3)]); //! let s = ca.into_series(); //! ``` //! @@ -89,8 +89,8 @@ //! ]?; //! //! // from a Vec -//! let s1 = Series::new("names", &["a", "b", "c"]); -//! let s2 = Series::new("values", &[Some(1), None, Some(3)]); +//! let s1 = Series::new("names".into(), &["a", "b", "c"]); +//! let s2 = Series::new("values".into(), &[Some(1), None, Some(3)]); //! let df = DataFrame::new(vec![s1, s2])?; //! # Ok(()) //! # } @@ -103,8 +103,8 @@ //! ``` //! use polars::prelude::*; //! # fn example() -> PolarsResult<()> { -//! let s_int = Series::new("a", &[1, 2, 3]); -//! let s_flt = Series::new("b", &[1.0, 2.0, 3.0]); +//! let s_int = Series::new("a".into(), &[1, 2, 3]); +//! let s_flt = Series::new("b".into(), &[1.0, 2.0, 3.0]); //! //! let added = &s_int + &s_flt; //! let subtracted = &s_int - &s_flt; @@ -125,7 +125,7 @@ //! let multiplied = s_flt * 2.0; //! //! // or broadcast Series to match the operands type -//! let added = &s_int * &Series::new("broadcast_me", &[10]); +//! let added = &s_int * &Series::new("broadcast_me".into(), &[10]); //! //! # Ok(()) //! # } @@ -136,7 +136,7 @@ //! //! ```rust //! # use polars::prelude::*; -//! let series = Series::new("foo", [1, 2, 3]); +//! let series = Series::new("foo".into(), [1, 2, 3]); //! //! // 1 / s //! let divide_one_by_s = 1.div(&series); @@ -151,7 +151,7 @@ //! //! ```rust //! # use polars::prelude::*; -//! let ca = UInt32Chunked::new("foo", &[1, 2, 3]); +//! let ca = UInt32Chunked::new("foo".into(), &[1, 2, 3]); //! //! // 1 / ca //! let divide_one_by_ca = ca.apply_values(|rhs| 1 / rhs); @@ -165,8 +165,8 @@ //! use polars::prelude::*; //! # fn example() -> PolarsResult<()> { //! -//! let s = Series::new("a", &[1, 2, 3]); -//! let ca = UInt32Chunked::new("b", &[Some(3), None, Some(1)]); +//! let s = Series::new("a".into(), &[1, 2, 3]); +//! let ca = UInt32Chunked::new("b".into(), &[Some(3), None, Some(1)]); //! //! // compare Series with numeric values //! // == @@ -251,11 +251,11 @@ //! # fn example() -> PolarsResult<()> { //! //! // apply a closure over all values -//! let s = Series::new("foo", &[Some(1), Some(2), None]); +//! let s = Series::new("foo".into(), &[Some(1), Some(2), None]); //! s.i32()?.apply_values(|value| value * 20); //! //! // count string lengths -//! let s = Series::new("foo", &["foo", "bar", "foobar"]); +//! let s = Series::new("foo".into(), &["foo", "bar", "foobar"]); //! unary_elementwise_values(s.str()?, |str_val| str_val.len() as u64); //! //! # Ok(()) @@ -506,15 +506,15 @@ //! use polars::df; //! //! # fn example(df: &DataFrame) -> PolarsResult<()> { -//! let s0 = Series::new("a", &[1i64, 2, 3]); -//! let s1 = Series::new("b", &[1i64, 1, 1]); -//! let s2 = Series::new("c", &[2i64, 2, 2]); +//! let s0 = Series::new("a".into(), &[1i64, 2, 3]); +//! let s1 = Series::new("b".into(), &[1i64, 1, 1]); +//! let s2 = Series::new("c".into(), &[2i64, 2, 2]); //! 
// construct a new ListChunked for a slice of Series. //! let list = Series::new("foo", &[s0, s1, s2]); //! //! // construct a few more Series. -//! let s0 = Series::new("B", [1, 2, 3]); -//! let s1 = Series::new("C", [1, 1, 1]); +//! let s0 = Series::new("B".into(), [1, 2, 3]); +//! let s1 = Series::new("C".into(), [1, 1, 1]); //! let df = DataFrame::new(vec![list, s0, s1])?; //! //! let exploded = df.explode(["foo"])?; diff --git a/crates/polars/src/lib.rs b/crates/polars/src/lib.rs index 00086736c6e5..9910df124fa5 100644 --- a/crates/polars/src/lib.rs +++ b/crates/polars/src/lib.rs @@ -405,6 +405,7 @@ //! `T` in complex lazy expressions. However this does require `unsafe` code allow this. //! * `POLARS_NO_PARQUET_STATISTICS` -> if set, statistics in parquet files are ignored. //! * `POLARS_PANIC_ON_ERR` -> panic instead of returning an Error. +//! * `POLARS_BACKTRACE_IN_ERR` -> include a Rust backtrace in Error messages. //! * `POLARS_NO_CHUNKED_JOIN` -> force rechunk before joins. //! //! ## User guide diff --git a/crates/polars/tests/it/arrow/array/binary/mod.rs b/crates/polars/tests/it/arrow/array/binary/mod.rs index 5446f2eb83f1..5c06c8e7680c 100644 --- a/crates/polars/tests/it/arrow/array/binary/mod.rs +++ b/crates/polars/tests/it/arrow/array/binary/mod.rs @@ -126,7 +126,7 @@ fn wrong_offsets() { #[test] #[should_panic] -fn wrong_data_type() { +fn wrong_dtype() { let offsets = vec![0, 4].try_into().unwrap(); let values = Buffer::from(b"abbb".to_vec()); BinaryArray::::new(ArrowDataType::Int8, offsets, values, None); diff --git a/crates/polars/tests/it/arrow/array/binary/mutable_values.rs b/crates/polars/tests/it/arrow/array/binary/mutable_values.rs index c9e4f1da3bbe..8d9500f2911d 100644 --- a/crates/polars/tests/it/arrow/array/binary/mutable_values.rs +++ b/crates/polars/tests/it/arrow/array/binary/mutable_values.rs @@ -21,7 +21,7 @@ fn offsets_must_be_in_bounds() { } #[test] -fn data_type_must_be_consistent() { +fn dtype_must_be_consistent() { let offsets = vec![0, 4].try_into().unwrap(); let values = b"abbb".to_vec(); assert!( diff --git a/crates/polars/tests/it/arrow/array/boolean/mod.rs b/crates/polars/tests/it/arrow/array/boolean/mod.rs index 5d96140bed75..2c0e3c4d6d76 100644 --- a/crates/polars/tests/it/arrow/array/boolean/mod.rs +++ b/crates/polars/tests/it/arrow/array/boolean/mod.rs @@ -13,7 +13,7 @@ fn array() -> BooleanArray { fn basics() { let array = array(); - assert_eq!(array.data_type(), &ArrowDataType::Boolean); + assert_eq!(array.dtype(), &ArrowDataType::Boolean); assert!(array.value(0)); assert!(!array.value(1)); diff --git a/crates/polars/tests/it/arrow/array/dictionary/mod.rs b/crates/polars/tests/it/arrow/array/dictionary/mod.rs index 924bbcab45a4..bb05634c7683 100644 --- a/crates/polars/tests/it/arrow/array/dictionary/mod.rs +++ b/crates/polars/tests/it/arrow/array/dictionary/mod.rs @@ -6,14 +6,10 @@ use arrow::datatypes::ArrowDataType; #[test] fn try_new_ok() { let values = Utf8Array::::from_slice(["a", "aa"]); - let data_type = - ArrowDataType::Dictionary(i32::KEY_TYPE, Box::new(values.data_type().clone()), false); - let array = DictionaryArray::try_new( - data_type, - PrimitiveArray::from_vec(vec![1, 0]), - values.boxed(), - ) - .unwrap(); + let dtype = ArrowDataType::Dictionary(i32::KEY_TYPE, Box::new(values.dtype().clone()), false); + let array = + DictionaryArray::try_new(dtype, PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .unwrap(); assert_eq!(array.keys(), &PrimitiveArray::from_vec(vec![1i32, 0])); assert_eq!( @@ -28,14 +24,10 @@ fn 
try_new_ok() { #[test] fn split_at() { let values = Utf8Array::::from_slice(["a", "aa"]); - let data_type = - ArrowDataType::Dictionary(i32::KEY_TYPE, Box::new(values.data_type().clone()), false); - let array = DictionaryArray::try_new( - data_type, - PrimitiveArray::from_vec(vec![1, 0]), - values.boxed(), - ) - .unwrap(); + let dtype = ArrowDataType::Dictionary(i32::KEY_TYPE, Box::new(values.dtype().clone()), false); + let array = + DictionaryArray::try_new(dtype, PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .unwrap(); let (lhs, rhs) = array.split_at(1); @@ -46,15 +38,10 @@ fn split_at() { #[test] fn try_new_incorrect_key() { let values = Utf8Array::::from_slice(["a", "aa"]); - let data_type = - ArrowDataType::Dictionary(i16::KEY_TYPE, Box::new(values.data_type().clone()), false); + let dtype = ArrowDataType::Dictionary(i16::KEY_TYPE, Box::new(values.dtype().clone()), false); - let r = DictionaryArray::try_new( - data_type, - PrimitiveArray::from_vec(vec![1, 0]), - values.boxed(), - ) - .is_err(); + let r = DictionaryArray::try_new(dtype, PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .is_err(); assert!(r); } @@ -66,9 +53,8 @@ fn try_new_nulls() { let value: &[&str] = &[]; let values = Utf8Array::::from_slice(value); - let data_type = - ArrowDataType::Dictionary(u32::KEY_TYPE, Box::new(values.data_type().clone()), false); - let r = DictionaryArray::try_new(data_type, keys, values.boxed()).is_ok(); + let dtype = ArrowDataType::Dictionary(u32::KEY_TYPE, Box::new(values.dtype().clone()), false); + let r = DictionaryArray::try_new(dtype, keys, values.boxed()).is_ok(); assert!(r); } @@ -76,14 +62,10 @@ fn try_new_nulls() { #[test] fn try_new_incorrect_dt() { let values = Utf8Array::::from_slice(["a", "aa"]); - let data_type = ArrowDataType::Int32; + let dtype = ArrowDataType::Int32; - let r = DictionaryArray::try_new( - data_type, - PrimitiveArray::from_vec(vec![1, 0]), - values.boxed(), - ) - .is_err(); + let r = DictionaryArray::try_new(dtype, PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .is_err(); assert!(r); } @@ -91,15 +73,10 @@ fn try_new_incorrect_dt() { #[test] fn try_new_incorrect_values_dt() { let values = Utf8Array::::from_slice(["a", "aa"]); - let data_type = - ArrowDataType::Dictionary(i32::KEY_TYPE, Box::new(ArrowDataType::LargeUtf8), false); - - let r = DictionaryArray::try_new( - data_type, - PrimitiveArray::from_vec(vec![1, 0]), - values.boxed(), - ) - .is_err(); + let dtype = ArrowDataType::Dictionary(i32::KEY_TYPE, Box::new(ArrowDataType::LargeUtf8), false); + + let r = DictionaryArray::try_new(dtype, PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .is_err(); assert!(r); } diff --git a/crates/polars/tests/it/arrow/array/equal/list.rs b/crates/polars/tests/it/arrow/array/equal/list.rs index 34370ad5459e..6deec984f7f6 100644 --- a/crates/polars/tests/it/arrow/array/equal/list.rs +++ b/crates/polars/tests/it/arrow/array/equal/list.rs @@ -67,7 +67,7 @@ fn test_list_offsets() { #[test] fn test_bla() { let offsets = vec![0, 3, 3, 6].try_into().unwrap(); - let data_type = ListArray::::default_datatype(ArrowDataType::Int32); + let dtype = ListArray::::default_datatype(ArrowDataType::Int32); let values = Box::new(Int32Array::from([ Some(1), Some(2), @@ -77,14 +77,14 @@ fn test_bla() { Some(6), ])); let validity = Bitmap::from([true, false, true]); - let lhs = ListArray::::new(data_type, offsets, values, Some(validity)); + let lhs = ListArray::::new(dtype, offsets, values, Some(validity)); let lhs = lhs.sliced(1, 2); let offsets = vec![0, 0, 
3].try_into().unwrap(); - let data_type = ListArray::::default_datatype(ArrowDataType::Int32); + let dtype = ListArray::::default_datatype(ArrowDataType::Int32); let values = Box::new(Int32Array::from([Some(4), None, Some(6)])); let validity = Bitmap::from([false, true]); - let rhs = ListArray::::new(data_type, offsets, values, Some(validity)); + let rhs = ListArray::::new(dtype, offsets, values, Some(validity)); assert_eq!(lhs, rhs); } diff --git a/crates/polars/tests/it/arrow/array/fixed_size_binary/mod.rs b/crates/polars/tests/it/arrow/array/fixed_size_binary/mod.rs index 12019be64205..141f5e8504c8 100644 --- a/crates/polars/tests/it/arrow/array/fixed_size_binary/mod.rs +++ b/crates/polars/tests/it/arrow/array/fixed_size_binary/mod.rs @@ -84,7 +84,7 @@ fn wrong_len() { } #[test] -fn wrong_data_type() { +fn wrong_dtype() { let values = Buffer::from(b"abba".to_vec()); assert!(FixedSizeBinaryArray::try_new(ArrowDataType::Binary, values, None).is_err()); } @@ -95,7 +95,7 @@ fn to() { let a = FixedSizeBinaryArray::new(ArrowDataType::FixedSizeBinary(2), values, None); let extension = ArrowDataType::Extension( - "a".to_string(), + "a".into(), Box::new(ArrowDataType::FixedSizeBinary(2)), None, ); diff --git a/crates/polars/tests/it/arrow/array/fixed_size_binary/mutable.rs b/crates/polars/tests/it/arrow/array/fixed_size_binary/mutable.rs index 316157087fbb..ad89efd256b1 100644 --- a/crates/polars/tests/it/arrow/array/fixed_size_binary/mutable.rs +++ b/crates/polars/tests/it/arrow/array/fixed_size_binary/mutable.rs @@ -11,7 +11,7 @@ fn basic() { ) .unwrap(); assert_eq!(a.len(), 2); - assert_eq!(a.data_type(), &ArrowDataType::FixedSizeBinary(2)); + assert_eq!(a.dtype(), &ArrowDataType::FixedSizeBinary(2)); assert_eq!(a.values(), &Vec::from([1, 2, 3, 4])); assert_eq!(a.validity(), None); assert_eq!(a.value(1), &[3, 4]); diff --git a/crates/polars/tests/it/arrow/array/fixed_size_list/mod.rs b/crates/polars/tests/it/arrow/array/fixed_size_list/mod.rs index 5b8ea7b8c950..5e3e4174f667 100644 --- a/crates/polars/tests/it/arrow/array/fixed_size_list/mod.rs +++ b/crates/polars/tests/it/arrow/array/fixed_size_list/mod.rs @@ -9,7 +9,7 @@ fn data() -> FixedSizeListArray { FixedSizeListArray::try_new( ArrowDataType::FixedSizeList( - Box::new(Field::new("a", values.data_type().clone(), true)), + Box::new(Field::new("a".into(), values.dtype().clone(), true)), 2, ), values.boxed(), @@ -59,7 +59,7 @@ fn debug() { #[test] fn empty() { let array = FixedSizeListArray::new_empty(ArrowDataType::FixedSizeList( - Box::new(Field::new("a", ArrowDataType::Int32, true)), + Box::new(Field::new("a".into(), ArrowDataType::Int32, true)), 2, )); assert_eq!(array.values().len(), 0); @@ -69,7 +69,10 @@ fn empty() { #[test] fn null() { let array = FixedSizeListArray::new_null( - ArrowDataType::FixedSizeList(Box::new(Field::new("a", ArrowDataType::Int32, true)), 2), + ArrowDataType::FixedSizeList( + Box::new(Field::new("a".into(), ArrowDataType::Int32, true)), + 2, + ), 2, ); assert_eq!(array.values().len(), 4); @@ -80,7 +83,10 @@ fn null() { fn wrong_size() { let values = Int32Array::from_slice([10, 20, 0]); assert!(FixedSizeListArray::try_new( - ArrowDataType::FixedSizeList(Box::new(Field::new("a", ArrowDataType::Int32, true)), 2), + ArrowDataType::FixedSizeList( + Box::new(Field::new("a".into(), ArrowDataType::Int32, true)), + 2 + ), values.boxed(), None ) @@ -91,7 +97,10 @@ fn wrong_size() { fn wrong_len() { let values = Int32Array::from_slice([10, 20, 0]); assert!(FixedSizeListArray::try_new( - 
ArrowDataType::FixedSizeList(Box::new(Field::new("a", ArrowDataType::Int32, true)), 2), + ArrowDataType::FixedSizeList( + Box::new(Field::new("a".into(), ArrowDataType::Int32, true)), + 2 + ), values.boxed(), Some([true, false, false].into()), // it should be 2 ) @@ -99,7 +108,7 @@ fn wrong_len() { } #[test] -fn wrong_data_type() { +fn wrong_dtype() { let values = Int32Array::from_slice([10, 20, 0]); assert!(FixedSizeListArray::try_new( ArrowDataType::Binary, diff --git a/crates/polars/tests/it/arrow/array/fixed_size_list/mutable.rs b/crates/polars/tests/it/arrow/array/fixed_size_list/mutable.rs index 23ea53231059..dc42b85c7e87 100644 --- a/crates/polars/tests/it/arrow/array/fixed_size_list/mutable.rs +++ b/crates/polars/tests/it/arrow/array/fixed_size_list/mutable.rs @@ -36,7 +36,7 @@ fn new_with_field() { let mut list = MutableFixedSizeListArray::new_with_field( MutablePrimitiveArray::::new(), - "custom_items", + "custom_items".into(), false, 3, ); @@ -44,9 +44,13 @@ fn new_with_field() { let list: FixedSizeListArray = list.into(); assert_eq!( - list.data_type(), + list.dtype(), &ArrowDataType::FixedSizeList( - Box::new(Field::new("custom_items", ArrowDataType::Int32, false)), + Box::new(Field::new( + "custom_items".into(), + ArrowDataType::Int32, + false + )), 3 ) ); diff --git a/crates/polars/tests/it/arrow/array/growable/dictionary.rs b/crates/polars/tests/it/arrow/array/growable/dictionary.rs index e2a48275d7ae..c84d95113d7c 100644 --- a/crates/polars/tests/it/arrow/array/growable/dictionary.rs +++ b/crates/polars/tests/it/arrow/array/growable/dictionary.rs @@ -7,14 +7,14 @@ fn test_single() -> PolarsResult<()> { let original_data = vec![Some("a"), Some("b"), Some("a")]; let data = original_data.clone(); - let mut array = MutableDictionaryArray::>::new(); + let mut array = MutableDictionaryArray::>::new(); array.try_extend(data)?; let array = array.into(); // same values, less keys let expected = DictionaryArray::try_from_keys( PrimitiveArray::from_vec(vec![1, 0]), - Box::new(Utf8Array::::from(&original_data)), + Box::new(Utf8ViewArray::from_slice(&original_data)), ) .unwrap(); @@ -39,11 +39,11 @@ fn test_multi() -> PolarsResult<()> { let data1 = original_data1.clone(); let data2 = original_data2.clone(); - let mut array1 = MutableDictionaryArray::>::new(); + let mut array1 = MutableDictionaryArray::>::new(); array1.try_extend(data1)?; let array1: DictionaryArray = array1.into(); - let mut array2 = MutableDictionaryArray::>::new(); + let mut array2 = MutableDictionaryArray::>::new(); array2.try_extend(data2)?; let array2: DictionaryArray = array2.into(); @@ -51,7 +51,7 @@ fn test_multi() -> PolarsResult<()> { original_data1.extend(original_data2.iter().cloned()); let expected = DictionaryArray::try_from_keys( PrimitiveArray::from(&[Some(1), None, Some(3), None]), - Utf8Array::::from_slice(["a", "b", "c", "b", "a"]).boxed(), + Utf8ViewArray::from_slice_values(["a", "b", "c", "b", "a"]).boxed(), ) .unwrap(); diff --git a/crates/polars/tests/it/arrow/array/growable/list.rs b/crates/polars/tests/it/arrow/array/growable/list.rs index 1bc0985ceb4f..faa1286a564c 100644 --- a/crates/polars/tests/it/arrow/array/growable/list.rs +++ b/crates/polars/tests/it/arrow/array/growable/list.rs @@ -18,10 +18,9 @@ fn extension() { let array = create_list_array(data); - let data_type = - ArrowDataType::Extension("ext".to_owned(), Box::new(array.data_type().clone()), None); + let dtype = ArrowDataType::Extension("ext".into(), Box::new(array.dtype().clone()), None); let array_ext = ListArray::new( - 
data_type, + dtype, array.offsets().clone(), array.values().clone(), array.validity().cloned(), @@ -34,7 +33,7 @@ fn extension() { assert_eq!(a.len(), 1); let result: ListArray = a.into(); - assert_eq!(array_ext.data_type(), result.data_type()); + assert_eq!(array_ext.dtype(), result.dtype()); } #[test] diff --git a/crates/polars/tests/it/arrow/array/growable/mod.rs b/crates/polars/tests/it/arrow/array/growable/mod.rs index 43496a1e95b1..4510fd0749cd 100644 --- a/crates/polars/tests/it/arrow/array/growable/mod.rs +++ b/crates/polars/tests/it/arrow/array/growable/mod.rs @@ -18,12 +18,6 @@ fn test_make_growable() { let array = Int32Array::from_slice([1, 2]); make_growable(&[&array], false, 2); - let array = Utf8Array::::from_slice(["a", "aa"]); - make_growable(&[&array], false, 2); - - let array = Utf8Array::::from_slice(["a", "aa"]); - make_growable(&[&array], false, 2); - let array = BinaryArray::::from_slice([b"a".as_ref(), b"aa".as_ref()]); make_growable(&[&array], false, 2); @@ -50,26 +44,25 @@ fn test_make_growable_extension() { .unwrap(); make_growable(&[&array], false, 2); - let data_type = - ArrowDataType::Extension("ext".to_owned(), Box::new(ArrowDataType::Int32), None); - let array = Int32Array::from_slice([1, 2]).to(data_type.clone()); + let dtype = ArrowDataType::Extension("ext".into(), Box::new(ArrowDataType::Int32), None); + let array = Int32Array::from_slice([1, 2]).to(dtype.clone()); let array_grown = make_growable(&[&array], false, 2).as_box(); - assert_eq!(array_grown.data_type(), &data_type); + assert_eq!(array_grown.dtype(), &dtype); - let data_type = ArrowDataType::Extension( - "ext".to_owned(), + let dtype = ArrowDataType::Extension( + "ext".into(), Box::new(ArrowDataType::Struct(vec![Field::new( - "a", + "a".into(), ArrowDataType::Int32, false, )])), None, ); let array = StructArray::new( - data_type.clone(), + dtype.clone(), vec![Int32Array::from_slice([1, 2]).boxed()], None, ); let array_grown = make_growable(&[&array], false, 2).as_box(); - assert_eq!(array_grown.data_type(), &data_type); + assert_eq!(array_grown.dtype(), &dtype); } diff --git a/crates/polars/tests/it/arrow/array/growable/struct_.rs b/crates/polars/tests/it/arrow/array/growable/struct_.rs index 809e70749f09..07f0403ee294 100644 --- a/crates/polars/tests/it/arrow/array/growable/struct_.rs +++ b/crates/polars/tests/it/arrow/array/growable/struct_.rs @@ -1,10 +1,10 @@ use arrow::array::growable::{Growable, GrowableStruct}; -use arrow::array::{Array, PrimitiveArray, StructArray, Utf8Array}; +use arrow::array::{Array, PrimitiveArray, StructArray, Utf8ViewArray}; use arrow::bitmap::Bitmap; use arrow::datatypes::{ArrowDataType, Field}; fn some_values() -> (ArrowDataType, Vec>) { - let strings: Box = Box::new(Utf8Array::::from([ + let strings: Box = Box::new(Utf8ViewArray::from_slice([ Some("a"), Some("aa"), None, @@ -19,8 +19,8 @@ fn some_values() -> (ArrowDataType, Vec>) { Some(5), ])); let fields = vec![ - Field::new("f1", ArrowDataType::Utf8, true), - Field::new("f2", ArrowDataType::Int32, true), + Field::new("f1".into(), ArrowDataType::Utf8View, true), + Field::new("f2".into(), ArrowDataType::Int32, true), ]; (ArrowDataType::Struct(fields), vec![strings, ints]) } @@ -115,7 +115,7 @@ fn many() { assert_eq!(mutable.len(), 5); let result = mutable.as_box(); - let expected_string: Box = Box::new(Utf8Array::::from([ + let expected_string: Box = Box::new(Utf8ViewArray::from_slice([ Some("aa"), None, Some("a"), diff --git a/crates/polars/tests/it/arrow/array/list/mod.rs 
b/crates/polars/tests/it/arrow/array/list/mod.rs index c65d7810de0a..37ab5d0e7e91 100644 --- a/crates/polars/tests/it/arrow/array/list/mod.rs +++ b/crates/polars/tests/it/arrow/array/list/mod.rs @@ -9,9 +9,9 @@ fn debug() { let values = Buffer::from(vec![1, 2, 3, 4, 5]); let values = PrimitiveArray::::new(ArrowDataType::Int32, values, None); - let data_type = ListArray::::default_datatype(ArrowDataType::Int32); + let dtype = ListArray::::default_datatype(ArrowDataType::Int32); let array = ListArray::::new( - data_type, + dtype, vec![0, 2, 2, 3, 5].try_into().unwrap(), Box::new(values), None, @@ -25,9 +25,9 @@ fn split_at() { let values = Buffer::from(vec![1, 2, 3, 4, 5]); let values = PrimitiveArray::::new(ArrowDataType::Int32, values, None); - let data_type = ListArray::::default_datatype(ArrowDataType::Int32); + let dtype = ListArray::::default_datatype(ArrowDataType::Int32); let array = ListArray::::new( - data_type, + dtype, vec![0, 2, 2, 3, 5].try_into().unwrap(), Box::new(values), None, @@ -45,9 +45,9 @@ fn test_nested_panic() { let values = Buffer::from(vec![1, 2, 3, 4, 5]); let values = PrimitiveArray::::new(ArrowDataType::Int32, values, None); - let data_type = ListArray::::default_datatype(ArrowDataType::Int32); + let dtype = ListArray::::default_datatype(ArrowDataType::Int32); let array = ListArray::::new( - data_type.clone(), + dtype.clone(), vec![0, 2, 2, 3, 5].try_into().unwrap(), Box::new(values), None, @@ -56,7 +56,7 @@ fn test_nested_panic() { // The datatype for the nested array has to be created considering // the nested structure of the child data let _ = ListArray::::new( - data_type, + dtype, vec![0, 2, 4].try_into().unwrap(), Box::new(array), None, @@ -68,17 +68,17 @@ fn test_nested_display() { let values = Buffer::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); let values = PrimitiveArray::::new(ArrowDataType::Int32, values, None); - let data_type = ListArray::::default_datatype(ArrowDataType::Int32); + let dtype = ListArray::::default_datatype(ArrowDataType::Int32); let array = ListArray::::new( - data_type, + dtype, vec![0, 2, 4, 7, 7, 8, 10].try_into().unwrap(), Box::new(values), None, ); - let data_type = ListArray::::default_datatype(array.data_type().clone()); + let dtype = ListArray::::default_datatype(array.dtype().clone()); let nested = ListArray::::new( - data_type, + dtype, vec![0, 2, 5, 6].try_into().unwrap(), Box::new(array), None, diff --git a/crates/polars/tests/it/arrow/array/list/mutable.rs b/crates/polars/tests/it/arrow/array/list/mutable.rs index 2d4ba0c4d2f1..6b4c60e3b459 100644 --- a/crates/polars/tests/it/arrow/array/list/mutable.rs +++ b/crates/polars/tests/it/arrow/array/list/mutable.rs @@ -21,9 +21,9 @@ fn basics() { Some(Bitmap::from([true, true, true, true, false, true])), ); - let data_type = ListArray::::default_datatype(ArrowDataType::Int32); + let dtype = ListArray::::default_datatype(ArrowDataType::Int32); let expected = ListArray::::new( - data_type, + dtype, vec![0, 3, 3, 6].try_into().unwrap(), Box::new(values), Some(Bitmap::from([true, false, true])), diff --git a/crates/polars/tests/it/arrow/array/map/mod.rs b/crates/polars/tests/it/arrow/array/map/mod.rs index a9e0fa62b317..34e880578659 100644 --- a/crates/polars/tests/it/arrow/array/map/mod.rs +++ b/crates/polars/tests/it/arrow/array/map/mod.rs @@ -3,13 +3,13 @@ use arrow::datatypes::{ArrowDataType, Field}; fn dt() -> ArrowDataType { ArrowDataType::Struct(vec![ - Field::new("a", ArrowDataType::Utf8, true), - Field::new("b", ArrowDataType::Utf8, true), + Field::new("a".into(), 
ArrowDataType::Utf8, true), + Field::new("b".into(), ArrowDataType::Utf8, true), ]) } fn array() -> MapArray { - let data_type = ArrowDataType::Map(Box::new(Field::new("a", dt(), true)), false); + let dtype = ArrowDataType::Map(Box::new(Field::new("a".into(), dt(), true)), false); let field = StructArray::new( dt(), @@ -21,7 +21,7 @@ fn array() -> MapArray { ); MapArray::new( - data_type, + dtype, vec![0, 1, 2, 3].try_into().unwrap(), Box::new(field), None, diff --git a/crates/polars/tests/it/arrow/array/mod.rs b/crates/polars/tests/it/arrow/array/mod.rs index 2dcb32ea6708..ccdf42f3b7c6 100644 --- a/crates/polars/tests/it/arrow/array/mod.rs +++ b/crates/polars/tests/it/arrow/array/mod.rs @@ -24,7 +24,11 @@ fn nulls() { ArrowDataType::Float64, ArrowDataType::Utf8, ArrowDataType::Binary, - ArrowDataType::List(Box::new(Field::new("a", ArrowDataType::Binary, true))), + ArrowDataType::List(Box::new(Field::new( + "a".into(), + ArrowDataType::Binary, + true, + ))), ]; let a = datatypes .into_iter() @@ -34,12 +38,12 @@ fn nulls() { // unions' null count is always 0 let datatypes = vec![ ArrowDataType::Union( - vec![Field::new("a", ArrowDataType::Binary, true)], + vec![Field::new("a".into(), ArrowDataType::Binary, true)], None, UnionMode::Dense, ), ArrowDataType::Union( - vec![Field::new("a", ArrowDataType::Binary, true)], + vec![Field::new("a".into(), ArrowDataType::Binary, true)], None, UnionMode::Sparse, ), @@ -57,23 +61,27 @@ fn empty() { ArrowDataType::Float64, ArrowDataType::Utf8, ArrowDataType::Binary, - ArrowDataType::List(Box::new(Field::new("a", ArrowDataType::Binary, true))), ArrowDataType::List(Box::new(Field::new( - "a", - ArrowDataType::Extension("ext".to_owned(), Box::new(ArrowDataType::Int32), None), + "a".into(), + ArrowDataType::Binary, + true, + ))), + ArrowDataType::List(Box::new(Field::new( + "a".into(), + ArrowDataType::Extension("ext".into(), Box::new(ArrowDataType::Int32), None), true, ))), ArrowDataType::Union( - vec![Field::new("a", ArrowDataType::Binary, true)], + vec![Field::new("a".into(), ArrowDataType::Binary, true)], None, UnionMode::Sparse, ), ArrowDataType::Union( - vec![Field::new("a", ArrowDataType::Binary, true)], + vec![Field::new("a".into(), ArrowDataType::Binary, true)], None, UnionMode::Dense, ), - ArrowDataType::Struct(vec![Field::new("a", ArrowDataType::Int32, true)]), + ArrowDataType::Struct(vec![Field::new("a".into(), ArrowDataType::Int32, true)]), ]; let a = datatypes.into_iter().all(|x| new_empty_array(x).len() == 0); assert!(a); @@ -86,25 +94,29 @@ fn empty_extension() { ArrowDataType::Float64, ArrowDataType::Utf8, ArrowDataType::Binary, - ArrowDataType::List(Box::new(Field::new("a", ArrowDataType::Binary, true))), + ArrowDataType::List(Box::new(Field::new( + "a".into(), + ArrowDataType::Binary, + true, + ))), ArrowDataType::Union( - vec![Field::new("a", ArrowDataType::Binary, true)], + vec![Field::new("a".into(), ArrowDataType::Binary, true)], None, UnionMode::Sparse, ), ArrowDataType::Union( - vec![Field::new("a", ArrowDataType::Binary, true)], + vec![Field::new("a".into(), ArrowDataType::Binary, true)], None, UnionMode::Dense, ), - ArrowDataType::Struct(vec![Field::new("a", ArrowDataType::Int32, true)]), + ArrowDataType::Struct(vec![Field::new("a".into(), ArrowDataType::Int32, true)]), ]; let a = datatypes .into_iter() - .map(|dt| ArrowDataType::Extension("ext".to_owned(), Box::new(dt), None)) + .map(|dt| ArrowDataType::Extension("ext".into(), Box::new(dt), None)) .all(|x| { let a = new_empty_array(x); - a.len() == 0 && matches!(a.data_type(), 
ArrowDataType::Extension(_, _, _)) + a.len() == 0 && matches!(a.dtype(), ArrowDataType::Extension(_, _, _)) }); assert!(a); } @@ -116,7 +128,11 @@ fn test_clone() { ArrowDataType::Float64, ArrowDataType::Utf8, ArrowDataType::Binary, - ArrowDataType::List(Box::new(Field::new("a", ArrowDataType::Binary, true))), + ArrowDataType::List(Box::new(Field::new( + "a".into(), + ArrowDataType::Binary, + true, + ))), ]; let a = datatypes .into_iter() diff --git a/crates/polars/tests/it/arrow/array/primitive/fmt.rs b/crates/polars/tests/it/arrow/array/primitive/fmt.rs index e670bc93fe7b..eb6c067b9ec2 100644 --- a/crates/polars/tests/it/arrow/array/primitive/fmt.rs +++ b/crates/polars/tests/it/arrow/array/primitive/fmt.rs @@ -117,7 +117,7 @@ fn debug_timestamp_ns() { fn debug_timestamp_tz_ns() { let array = Int64Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Timestamp( TimeUnit::Nanosecond, - Some("+02:00".to_string()), + Some("+02:00".into()), )); assert_eq!( format!("{array:?}"), @@ -129,7 +129,7 @@ fn debug_timestamp_tz_ns() { fn debug_timestamp_tz_not_parsable() { let array = Int64Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Timestamp( TimeUnit::Nanosecond, - Some("aa".to_string()), + Some("aa".into()), )); assert_eq!( format!("{array:?}"), @@ -142,7 +142,7 @@ fn debug_timestamp_tz_not_parsable() { fn debug_timestamp_tz1_ns() { let array = Int64Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Timestamp( TimeUnit::Nanosecond, - Some("Europe/Lisbon".to_string()), + Some("Europe/Lisbon".into()), )); assert_eq!( format!("{array:?}"), diff --git a/crates/polars/tests/it/arrow/array/primitive/mod.rs b/crates/polars/tests/it/arrow/array/primitive/mod.rs index d19722d7fb74..d0630c854103 100644 --- a/crates/polars/tests/it/arrow/array/primitive/mod.rs +++ b/crates/polars/tests/it/arrow/array/primitive/mod.rs @@ -103,7 +103,7 @@ fn months_days_ns_from_slice() { } #[test] -fn wrong_data_type() { +fn wrong_dtype() { let values = Buffer::from(b"abbb".to_vec()); assert!(PrimitiveArray::try_new(ArrowDataType::Utf8, values, None).is_err()); } diff --git a/crates/polars/tests/it/arrow/array/primitive/mutable.rs b/crates/polars/tests/it/arrow/array/primitive/mutable.rs index bd4d3831dc82..4bc29890de07 100644 --- a/crates/polars/tests/it/arrow/array/primitive/mutable.rs +++ b/crates/polars/tests/it/arrow/array/primitive/mutable.rs @@ -33,7 +33,7 @@ fn to() { ) .unwrap(); let a = a.to(ArrowDataType::Date32); - assert_eq!(a.data_type(), &ArrowDataType::Date32); + assert_eq!(a.dtype(), &ArrowDataType::Date32); } #[test] @@ -311,7 +311,7 @@ fn try_from_trusted_len_iter() { } #[test] -fn wrong_data_type() { +fn wrong_dtype() { assert!(MutablePrimitiveArray::::try_new(ArrowDataType::Utf8, vec![], None).is_err()); } diff --git a/crates/polars/tests/it/arrow/array/struct_/iterator.rs b/crates/polars/tests/it/arrow/array/struct_/iterator.rs index 5b4b0b784d13..e4b6a7691ad0 100644 --- a/crates/polars/tests/it/arrow/array/struct_/iterator.rs +++ b/crates/polars/tests/it/arrow/array/struct_/iterator.rs @@ -8,8 +8,8 @@ fn test_simple_iter() { let int = Int32Array::from_slice([42, 28, 19, 31]).boxed(); let fields = vec![ - Field::new("b", ArrowDataType::Boolean, false), - Field::new("c", ArrowDataType::Int32, false), + Field::new("b".into(), ArrowDataType::Boolean, false), + Field::new("c".into(), ArrowDataType::Int32, false), ]; let array = StructArray::new( diff --git a/crates/polars/tests/it/arrow/array/struct_/mod.rs b/crates/polars/tests/it/arrow/array/struct_/mod.rs index 
5af6556096bc..bd1a1c83086c 100644 --- a/crates/polars/tests/it/arrow/array/struct_/mod.rs +++ b/crates/polars/tests/it/arrow/array/struct_/mod.rs @@ -10,8 +10,8 @@ fn array() -> StructArray { let int = Int32Array::from_slice([42, 28, 19, 31]).boxed(); let fields = vec![ - Field::new("b", ArrowDataType::Boolean, false), - Field::new("c", ArrowDataType::Int32, false), + Field::new("b".into(), ArrowDataType::Boolean, false), + Field::new("c".into(), ArrowDataType::Int32, false), ]; StructArray::new( diff --git a/crates/polars/tests/it/arrow/array/struct_/mutable.rs b/crates/polars/tests/it/arrow/array/struct_/mutable.rs index e9d698aa1bb3..4a526a76391b 100644 --- a/crates/polars/tests/it/arrow/array/struct_/mutable.rs +++ b/crates/polars/tests/it/arrow/array/struct_/mutable.rs @@ -5,8 +5,8 @@ use arrow::datatypes::{ArrowDataType, Field}; fn push() { let c1 = Box::new(MutablePrimitiveArray::::new()) as Box; let values = vec![c1]; - let data_type = ArrowDataType::Struct(vec![Field::new("f1", ArrowDataType::Int32, true)]); - let mut a = MutableStructArray::new(data_type, values); + let dtype = ArrowDataType::Struct(vec![Field::new("f1".into(), ArrowDataType::Int32, true)]); + let mut a = MutableStructArray::new(dtype, values); a.value::>(0) .unwrap() diff --git a/crates/polars/tests/it/arrow/array/union.rs b/crates/polars/tests/it/arrow/array/union.rs index b358aa8e44bb..3a3939349fb7 100644 --- a/crates/polars/tests/it/arrow/array/union.rs +++ b/crates/polars/tests/it/arrow/array/union.rs @@ -20,17 +20,17 @@ where #[test] fn sparse_debug() -> PolarsResult<()> { let fields = vec![ - Field::new("a", ArrowDataType::Int32, true), - Field::new("b", ArrowDataType::Utf8, true), + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), ]; - let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let dtype = ArrowDataType::Union(fields, None, UnionMode::Sparse); let types = vec![0, 0, 1].into(); let fields = vec![ Int32Array::from(&[Some(1), None, Some(2)]).boxed(), Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), ]; - let array = UnionArray::new(data_type, types, fields, None); + let array = UnionArray::new(dtype, types, fields, None); assert_eq!(format!("{array:?}"), "UnionArray[1, None, c]"); @@ -40,10 +40,10 @@ fn sparse_debug() -> PolarsResult<()> { #[test] fn dense_debug() -> PolarsResult<()> { let fields = vec![ - Field::new("a", ArrowDataType::Int32, true), - Field::new("b", ArrowDataType::Utf8, true), + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), ]; - let data_type = ArrowDataType::Union(fields, None, UnionMode::Dense); + let dtype = ArrowDataType::Union(fields, None, UnionMode::Dense); let types = vec![0, 0, 1].into(); let fields = vec![ Int32Array::from(&[Some(1), None, Some(2)]).boxed(), @@ -51,7 +51,7 @@ fn dense_debug() -> PolarsResult<()> { ]; let offsets = Some(vec![0, 1, 0].into()); - let array = UnionArray::new(data_type, types, fields, offsets); + let array = UnionArray::new(dtype, types, fields, offsets); assert_eq!(format!("{array:?}"), "UnionArray[1, None, c]"); @@ -61,17 +61,17 @@ fn dense_debug() -> PolarsResult<()> { #[test] fn slice() -> PolarsResult<()> { let fields = vec![ - Field::new("a", ArrowDataType::Int32, true), - Field::new("b", ArrowDataType::LargeUtf8, true), + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::LargeUtf8, true), ]; - let data_type = ArrowDataType::Union(fields, None, 
UnionMode::Sparse); + let dtype = ArrowDataType::Union(fields, None, UnionMode::Sparse); let types = Buffer::from(vec![0, 0, 1]); let fields = vec![ Int32Array::from(&[Some(1), None, Some(2)]).boxed(), Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), ]; - let array = UnionArray::new(data_type.clone(), types, fields.clone(), None); + let array = UnionArray::new(dtype.clone(), types, fields.clone(), None); let result = array.sliced(1, 2); @@ -80,7 +80,7 @@ fn slice() -> PolarsResult<()> { Int32Array::from(&[None, Some(2)]).boxed(), Utf8Array::::from([Some("b"), Some("c")]).boxed(), ]; - let expected = UnionArray::new(data_type, sliced_types, sliced_fields, None); + let expected = UnionArray::new(dtype, sliced_types, sliced_fields, None); assert_eq!(expected, result); Ok(()) @@ -89,17 +89,17 @@ fn slice() -> PolarsResult<()> { #[test] fn iter_sparse() -> PolarsResult<()> { let fields = vec![ - Field::new("a", ArrowDataType::Int32, true), - Field::new("b", ArrowDataType::Utf8, true), + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), ]; - let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let dtype = ArrowDataType::Union(fields, None, UnionMode::Sparse); let types = Buffer::from(vec![0, 0, 1]); let fields = vec![ Int32Array::from(&[Some(1), None, Some(2)]).boxed(), Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), ]; - let array = UnionArray::new(data_type, types, fields.clone(), None); + let array = UnionArray::new(dtype, types, fields.clone(), None); let mut iter = array.iter(); assert_eq!( @@ -122,10 +122,10 @@ fn iter_sparse() -> PolarsResult<()> { #[test] fn iter_dense() -> PolarsResult<()> { let fields = vec![ - Field::new("a", ArrowDataType::Int32, true), - Field::new("b", ArrowDataType::Utf8, true), + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), ]; - let data_type = ArrowDataType::Union(fields, None, UnionMode::Dense); + let dtype = ArrowDataType::Union(fields, None, UnionMode::Dense); let types = Buffer::from(vec![0, 0, 1]); let offsets = Buffer::::from(vec![0, 1, 0]); let fields = vec![ @@ -133,7 +133,7 @@ fn iter_dense() -> PolarsResult<()> { Utf8Array::::from([Some("c")]).boxed(), ]; - let array = UnionArray::new(data_type, types, fields.clone(), Some(offsets)); + let array = UnionArray::new(dtype, types, fields.clone(), Some(offsets)); let mut iter = array.iter(); assert_eq!( @@ -156,17 +156,17 @@ fn iter_dense() -> PolarsResult<()> { #[test] fn iter_sparse_slice() -> PolarsResult<()> { let fields = vec![ - Field::new("a", ArrowDataType::Int32, true), - Field::new("b", ArrowDataType::Utf8, true), + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), ]; - let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let dtype = ArrowDataType::Union(fields, None, UnionMode::Sparse); let types = Buffer::from(vec![0, 0, 1]); let fields = vec![ Int32Array::from(&[Some(1), Some(3), Some(2)]).boxed(), Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), ]; - let array = UnionArray::new(data_type, types, fields.clone(), None); + let array = UnionArray::new(dtype, types, fields.clone(), None); let array_slice = array.sliced(1, 1); let mut iter = array_slice.iter(); @@ -182,10 +182,10 @@ fn iter_sparse_slice() -> PolarsResult<()> { #[test] fn iter_dense_slice() -> PolarsResult<()> { let fields = vec![ - Field::new("a", ArrowDataType::Int32, true), - 
Field::new("b", ArrowDataType::Utf8, true), + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), ]; - let data_type = ArrowDataType::Union(fields, None, UnionMode::Dense); + let dtype = ArrowDataType::Union(fields, None, UnionMode::Dense); let types = Buffer::from(vec![0, 0, 1]); let offsets = Buffer::::from(vec![0, 1, 0]); let fields = vec![ @@ -193,7 +193,7 @@ fn iter_dense_slice() -> PolarsResult<()> { Utf8Array::::from([Some("c")]).boxed(), ]; - let array = UnionArray::new(data_type, types, fields.clone(), Some(offsets)); + let array = UnionArray::new(dtype, types, fields.clone(), Some(offsets)); let array_slice = array.sliced(1, 1); let mut iter = array_slice.iter(); @@ -209,10 +209,10 @@ fn iter_dense_slice() -> PolarsResult<()> { #[test] fn scalar() -> PolarsResult<()> { let fields = vec![ - Field::new("a", ArrowDataType::Int32, true), - Field::new("b", ArrowDataType::Utf8, true), + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), ]; - let data_type = ArrowDataType::Union(fields, None, UnionMode::Dense); + let dtype = ArrowDataType::Union(fields, None, UnionMode::Dense); let types = Buffer::from(vec![0, 0, 1]); let offsets = Buffer::::from(vec![0, 1, 0]); let fields = vec![ @@ -220,7 +220,7 @@ fn scalar() -> PolarsResult<()> { Utf8Array::::from([Some("c")]).boxed(), ]; - let array = UnionArray::new(data_type, types, fields.clone(), Some(offsets)); + let array = UnionArray::new(dtype, types, fields.clone(), Some(offsets)); let scalar = new_scalar(&array, 0); let union_scalar = scalar.as_any().downcast_ref::().unwrap(); @@ -266,42 +266,42 @@ fn scalar() -> PolarsResult<()> { #[test] fn dense_without_offsets_is_error() { let fields = vec![ - Field::new("a", ArrowDataType::Int32, true), - Field::new("b", ArrowDataType::Utf8, true), + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), ]; - let data_type = ArrowDataType::Union(fields, None, UnionMode::Dense); + let dtype = ArrowDataType::Union(fields, None, UnionMode::Dense); let types = vec![0, 0, 1].into(); let fields = vec![ Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), ]; - assert!(UnionArray::try_new(data_type, types, fields.clone(), None).is_err()); + assert!(UnionArray::try_new(dtype, types, fields.clone(), None).is_err()); } #[test] fn fields_must_match() { let fields = vec![ - Field::new("a", ArrowDataType::Int64, true), - Field::new("b", ArrowDataType::Utf8, true), + Field::new("a".into(), ArrowDataType::Int64, true), + Field::new("b".into(), ArrowDataType::Utf8, true), ]; - let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let dtype = ArrowDataType::Union(fields, None, UnionMode::Sparse); let types = vec![0, 0, 1].into(); let fields = vec![ Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), ]; - assert!(UnionArray::try_new(data_type, types, fields.clone(), None).is_err()); + assert!(UnionArray::try_new(dtype, types, fields.clone(), None).is_err()); } #[test] fn sparse_with_offsets_is_error() { let fields = vec![ - Field::new("a", ArrowDataType::Int32, true), - Field::new("b", ArrowDataType::Utf8, true), + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), ]; - let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let dtype = 
ArrowDataType::Union(fields, None, UnionMode::Sparse); let fields = vec![ Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), @@ -310,16 +310,16 @@ fn sparse_with_offsets_is_error() { let types = vec![0, 0, 1].into(); let offsets = vec![0, 1, 0].into(); - assert!(UnionArray::try_new(data_type, types, fields.clone(), Some(offsets)).is_err()); + assert!(UnionArray::try_new(dtype, types, fields.clone(), Some(offsets)).is_err()); } #[test] fn offsets_must_be_in_bounds() { let fields = vec![ - Field::new("a", ArrowDataType::Int32, true), - Field::new("b", ArrowDataType::Utf8, true), + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), ]; - let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let dtype = ArrowDataType::Union(fields, None, UnionMode::Sparse); let fields = vec![ Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), @@ -329,16 +329,16 @@ fn offsets_must_be_in_bounds() { // it must be equal to length of types let offsets = vec![0, 1].into(); - assert!(UnionArray::try_new(data_type, types, fields.clone(), Some(offsets)).is_err()); + assert!(UnionArray::try_new(dtype, types, fields.clone(), Some(offsets)).is_err()); } #[test] fn sparse_with_wrong_offsets1_is_error() { let fields = vec![ - Field::new("a", ArrowDataType::Int32, true), - Field::new("b", ArrowDataType::Utf8, true), + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), ]; - let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let dtype = ArrowDataType::Union(fields, None, UnionMode::Sparse); let fields = vec![ Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), @@ -348,16 +348,16 @@ fn sparse_with_wrong_offsets1_is_error() { // it must be equal to length of types let offsets = vec![0, 1, 10].into(); - assert!(UnionArray::try_new(data_type, types, fields.clone(), Some(offsets)).is_err()); + assert!(UnionArray::try_new(dtype, types, fields.clone(), Some(offsets)).is_err()); } #[test] fn types_must_be_in_bounds() -> PolarsResult<()> { let fields = vec![ - Field::new("a", ArrowDataType::Int32, true), - Field::new("b", ArrowDataType::Utf8, true), + Field::new("a".into(), ArrowDataType::Int32, true), + Field::new("b".into(), ArrowDataType::Utf8, true), ]; - let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let dtype = ArrowDataType::Union(fields, None, UnionMode::Sparse); let fields = vec![ Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), @@ -366,6 +366,6 @@ fn types_must_be_in_bounds() -> PolarsResult<()> { // 10 > num fields let types = vec![0, 10].into(); - assert!(UnionArray::try_new(data_type, types, fields.clone(), None).is_err()); + assert!(UnionArray::try_new(dtype, types, fields.clone(), None).is_err()); Ok(()) } diff --git a/crates/polars/tests/it/arrow/array/utf8/mod.rs b/crates/polars/tests/it/arrow/array/utf8/mod.rs index 9f7b302e3d99..89774dff4ae9 100644 --- a/crates/polars/tests/it/arrow/array/utf8/mod.rs +++ b/crates/polars/tests/it/arrow/array/utf8/mod.rs @@ -148,7 +148,7 @@ fn not_utf8_individually() { } #[test] -fn wrong_data_type() { +fn wrong_dtype() { let offsets = vec![0, 4].try_into().unwrap(); let values = b"abbb".to_vec().into(); assert!(Utf8Array::::try_new(ArrowDataType::Int32,
offsets, values, None).is_err()); diff --git a/crates/polars/tests/it/arrow/array/utf8/mutable.rs b/crates/polars/tests/it/arrow/array/utf8/mutable.rs index 8db873a90d10..7f1725957085 100644 --- a/crates/polars/tests/it/arrow/array/utf8/mutable.rs +++ b/crates/polars/tests/it/arrow/array/utf8/mutable.rs @@ -74,7 +74,7 @@ fn not_utf8() { } #[test] -fn wrong_data_type() { +fn wrong_dtype() { let offsets = vec![0, 4].try_into().unwrap(); let values = vec![1, 2, 3, 4]; assert!(MutableUtf8Array::::try_new(ArrowDataType::Int8, offsets, values, None).is_err()); diff --git a/crates/polars/tests/it/arrow/array/utf8/mutable_values.rs b/crates/polars/tests/it/arrow/array/utf8/mutable_values.rs index d4a309949934..d70316e18950 100644 --- a/crates/polars/tests/it/arrow/array/utf8/mutable_values.rs +++ b/crates/polars/tests/it/arrow/array/utf8/mutable_values.rs @@ -19,7 +19,7 @@ fn offsets_must_be_in_bounds() { } #[test] -fn data_type_must_be_consistent() { +fn dtype_must_be_consistent() { let offsets = vec![0, 4].try_into().unwrap(); let values = b"abbb".to_vec(); assert!(MutableUtf8ValuesArray::::try_new(ArrowDataType::Int32, offsets, values).is_err()); diff --git a/crates/polars/tests/it/arrow/compute/aggregate/memory.rs b/crates/polars/tests/it/arrow/compute/aggregate/memory.rs index 3f31240b8602..45e19d194a46 100644 --- a/crates/polars/tests/it/arrow/compute/aggregate/memory.rs +++ b/crates/polars/tests/it/arrow/compute/aggregate/memory.rs @@ -22,11 +22,11 @@ fn utf8() { #[test] fn fixed_size_list() { - let data_type = ArrowDataType::FixedSizeList( - Box::new(Field::new("elem", ArrowDataType::Float32, false)), + let dtype = ArrowDataType::FixedSizeList( + Box::new(Field::new("elem".into(), ArrowDataType::Float32, false)), 3, ); let values = Box::new(Float32Array::from_slice([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])); - let a = FixedSizeListArray::new(data_type, values, None); + let a = FixedSizeListArray::new(dtype, values, None); assert_eq!(6 * std::mem::size_of::(), estimated_bytes_size(&a)); } diff --git a/crates/polars/tests/it/arrow/ffi/data.rs b/crates/polars/tests/it/arrow/ffi/data.rs index bb798a1bc4fc..24caed1c482e 100644 --- a/crates/polars/tests/it/arrow/ffi/data.rs +++ b/crates/polars/tests/it/arrow/ffi/data.rs @@ -4,16 +4,15 @@ use arrow::ffi; use polars_error::PolarsResult; fn _test_round_trip(array: Box, expected: Box) -> PolarsResult<()> { - let field = Field::new("a", array.data_type().clone(), true); + let field = Field::new("a".into(), array.dtype().clone(), true); - // export array and corresponding data_type + // export array and corresponding dtype let array_ffi = ffi::export_array_to_c(array); let schema_ffi = ffi::export_field_to_c(&field); // import references let result_field = unsafe { ffi::import_field_from_c(&schema_ffi)? }; - let result_array = - unsafe { ffi::import_array_from_c(array_ffi, result_field.data_type.clone())? }; + let result_array = unsafe { ffi::import_array_from_c(array_ffi, result_field.dtype.clone())? 
}; assert_eq!(&result_array, &expected); assert_eq!(result_field, field); diff --git a/crates/polars/tests/it/arrow/ffi/stream.rs b/crates/polars/tests/it/arrow/ffi/stream.rs index f949fdf4c88e..542ce1c1e0fb 100644 --- a/crates/polars/tests/it/arrow/ffi/stream.rs +++ b/crates/polars/tests/it/arrow/ffi/stream.rs @@ -4,7 +4,7 @@ use arrow::ffi; use polars_error::{PolarsError, PolarsResult}; fn _test_round_trip(arrays: Vec>) -> PolarsResult<()> { - let field = Field::new("a", arrays[0].data_type().clone(), true); + let field = Field::new("a".into(), arrays[0].dtype().clone(), true); let iter = Box::new(arrays.clone().into_iter().map(Ok)) as _; let mut stream = Box::new(ffi::ArrowArrayStream::empty()); diff --git a/crates/polars/tests/it/arrow/io/ipc/mod.rs b/crates/polars/tests/it/arrow/io/ipc/mod.rs index c5b622305d42..3dfd0aedd276 100644 --- a/crates/polars/tests/it/arrow/io/ipc/mod.rs +++ b/crates/polars/tests/it/arrow/io/ipc/mod.rs @@ -7,6 +7,7 @@ use arrow::io::ipc::read::{read_file_metadata, FileReader}; use arrow::io::ipc::write::*; use arrow::io::ipc::IpcField; use arrow::record_batch::RecordBatchT; +use polars::prelude::PlSmallStr; use polars_error::*; pub(crate) fn write( @@ -49,8 +50,12 @@ fn round_trip( } fn prep_schema(array: &dyn Array) -> ArrowSchemaRef { - let fields = vec![Field::new("a", array.data_type().clone(), true)]; - Arc::new(ArrowSchema::from(fields)) + let name = PlSmallStr::from_static("a"); + Arc::new(ArrowSchema::from_iter([Field::new( + name, + array.dtype().clone(), + true, + )])) } #[test] diff --git a/crates/polars/tests/it/arrow/scalar/binary.rs b/crates/polars/tests/it/arrow/scalar/binary.rs index d1b3e984d379..09cb8fba254b 100644 --- a/crates/polars/tests/it/arrow/scalar/binary.rs +++ b/crates/polars/tests/it/arrow/scalar/binary.rs @@ -19,12 +19,12 @@ fn basics() { let a = BinaryScalar::::from(Some("a")); assert_eq!(a.value(), Some(b"a".as_ref())); - assert_eq!(a.data_type(), &ArrowDataType::Binary); + assert_eq!(a.dtype(), &ArrowDataType::Binary); assert!(a.is_valid()); let a = BinaryScalar::::from(None::<&str>); - assert_eq!(a.data_type(), &ArrowDataType::LargeBinary); + assert_eq!(a.dtype(), &ArrowDataType::LargeBinary); assert!(!a.is_valid()); let _: &dyn std::any::Any = a.as_any(); diff --git a/crates/polars/tests/it/arrow/scalar/boolean.rs b/crates/polars/tests/it/arrow/scalar/boolean.rs index 7c400b0fde3e..76128d7355bd 100644 --- a/crates/polars/tests/it/arrow/scalar/boolean.rs +++ b/crates/polars/tests/it/arrow/scalar/boolean.rs @@ -19,7 +19,7 @@ fn basics() { let a = BooleanScalar::new(Some(true)); assert_eq!(a.value(), Some(true)); - assert_eq!(a.data_type(), &ArrowDataType::Boolean); + assert_eq!(a.dtype(), &ArrowDataType::Boolean); assert!(a.is_valid()); let _: &dyn std::any::Any = a.as_any(); diff --git a/crates/polars/tests/it/arrow/scalar/fixed_size_binary.rs b/crates/polars/tests/it/arrow/scalar/fixed_size_binary.rs index c83bc4d69749..779e80e86c37 100644 --- a/crates/polars/tests/it/arrow/scalar/fixed_size_binary.rs +++ b/crates/polars/tests/it/arrow/scalar/fixed_size_binary.rs @@ -19,7 +19,7 @@ fn basics() { let a = FixedSizeBinaryScalar::new(ArrowDataType::FixedSizeBinary(1), Some("a")); assert_eq!(a.value(), Some(b"a".as_ref())); - assert_eq!(a.data_type(), &ArrowDataType::FixedSizeBinary(1)); + assert_eq!(a.dtype(), &ArrowDataType::FixedSizeBinary(1)); assert!(a.is_valid()); let _: &dyn std::any::Any = a.as_any(); diff --git a/crates/polars/tests/it/arrow/scalar/fixed_size_list.rs 
b/crates/polars/tests/it/arrow/scalar/fixed_size_list.rs index 2aa6f45bbd74..eb4084792c33 100644 --- a/crates/polars/tests/it/arrow/scalar/fixed_size_list.rs +++ b/crates/polars/tests/it/arrow/scalar/fixed_size_list.rs @@ -5,8 +5,10 @@ use arrow::scalar::{FixedSizeListScalar, Scalar}; #[allow(clippy::eq_op)] #[test] fn equal() { - let dt = - ArrowDataType::FixedSizeList(Box::new(Field::new("a", ArrowDataType::Boolean, true)), 2); + let dt = ArrowDataType::FixedSizeList( + Box::new(Field::new("a".into(), ArrowDataType::Boolean, true)), + 2, + ); let a = FixedSizeListScalar::new( dt.clone(), Some(BooleanArray::from_slice([true, false]).boxed()), @@ -25,8 +27,10 @@ fn equal() { #[test] fn basics() { - let dt = - ArrowDataType::FixedSizeList(Box::new(Field::new("a", ArrowDataType::Boolean, true)), 2); + let dt = ArrowDataType::FixedSizeList( + Box::new(Field::new("a".into(), ArrowDataType::Boolean, true)), + 2, + ); let a = FixedSizeListScalar::new( dt.clone(), Some(BooleanArray::from_slice([true, false]).boxed()), @@ -36,7 +40,7 @@ fn basics() { BooleanArray::from_slice([true, false]), a.values().unwrap().as_ref() ); - assert_eq!(a.data_type(), &dt); + assert_eq!(a.dtype(), &dt); assert!(a.is_valid()); let _: &dyn std::any::Any = a.as_any(); diff --git a/crates/polars/tests/it/arrow/scalar/list.rs b/crates/polars/tests/it/arrow/scalar/list.rs index 7cd2938237c9..d8acce251fd1 100644 --- a/crates/polars/tests/it/arrow/scalar/list.rs +++ b/crates/polars/tests/it/arrow/scalar/list.rs @@ -5,7 +5,11 @@ use arrow::scalar::{ListScalar, Scalar}; #[allow(clippy::eq_op)] #[test] fn equal() { - let dt = ArrowDataType::List(Box::new(Field::new("a", ArrowDataType::Boolean, true))); + let dt = ArrowDataType::List(Box::new(Field::new( + "a".into(), + ArrowDataType::Boolean, + true, + ))); let a = ListScalar::::new( dt.clone(), Some(BooleanArray::from_slice([true, false]).boxed()), @@ -21,14 +25,18 @@ fn equal() { #[test] fn basics() { - let dt = ArrowDataType::List(Box::new(Field::new("a", ArrowDataType::Boolean, true))); + let dt = ArrowDataType::List(Box::new(Field::new( + "a".into(), + ArrowDataType::Boolean, + true, + ))); let a = ListScalar::::new( dt.clone(), Some(BooleanArray::from_slice([true, false]).boxed()), ); assert_eq!(BooleanArray::from_slice([true, false]), a.values().as_ref()); - assert_eq!(a.data_type(), &dt); + assert_eq!(a.dtype(), &dt); assert!(a.is_valid()); let _: &dyn std::any::Any = a.as_any(); diff --git a/crates/polars/tests/it/arrow/scalar/map.rs b/crates/polars/tests/it/arrow/scalar/map.rs index e9f0ede0784f..ee23cb47960f 100644 --- a/crates/polars/tests/it/arrow/scalar/map.rs +++ b/crates/polars/tests/it/arrow/scalar/map.rs @@ -6,8 +6,8 @@ use arrow::scalar::{MapScalar, Scalar}; #[test] fn equal() { let kv_dt = ArrowDataType::Struct(vec![ - Field::new("key", ArrowDataType::Utf8, false), - Field::new("value", ArrowDataType::Boolean, true), + Field::new("key".into(), ArrowDataType::Utf8, false), + Field::new("value".into(), ArrowDataType::Boolean, true), ]); let kv_array1 = StructArray::try_new( kv_dt.clone(), @@ -28,7 +28,7 @@ fn equal() { ) .unwrap(); - let dt = ArrowDataType::Map(Box::new(Field::new("entries", kv_dt, true)), false); + let dt = ArrowDataType::Map(Box::new(Field::new("entries".into(), kv_dt, true)), false); let a = MapScalar::new(dt.clone(), Some(Box::new(kv_array1))); let b = MapScalar::new(dt.clone(), None); assert_eq!(a, a); @@ -42,8 +42,8 @@ fn equal() { #[test] fn basics() { let kv_dt = ArrowDataType::Struct(vec![ - Field::new("key", ArrowDataType::Utf8, 
false), - Field::new("value", ArrowDataType::Boolean, true), + Field::new("key".into(), ArrowDataType::Utf8, false), + Field::new("value".into(), ArrowDataType::Boolean, true), ]); let kv_array = StructArray::try_new( kv_dt.clone(), @@ -55,11 +55,11 @@ fn basics() { ) .unwrap(); - let dt = ArrowDataType::Map(Box::new(Field::new("entries", kv_dt, true)), false); + let dt = ArrowDataType::Map(Box::new(Field::new("entries".into(), kv_dt, true)), false); let a = MapScalar::new(dt.clone(), Some(Box::new(kv_array.clone()))); assert_eq!(kv_array, a.values().as_ref()); - assert_eq!(a.data_type(), &dt); + assert_eq!(a.dtype(), &dt); assert!(a.is_valid()); let _: &dyn std::any::Any = a.as_any(); diff --git a/crates/polars/tests/it/arrow/scalar/null.rs b/crates/polars/tests/it/arrow/scalar/null.rs index 3ceaf69f83b6..25f68534aa87 100644 --- a/crates/polars/tests/it/arrow/scalar/null.rs +++ b/crates/polars/tests/it/arrow/scalar/null.rs @@ -12,7 +12,7 @@ fn equal() { fn basics() { let a = NullScalar::default(); - assert_eq!(a.data_type(), &ArrowDataType::Null); + assert_eq!(a.dtype(), &ArrowDataType::Null); assert!(!a.is_valid()); let _: &dyn std::any::Any = a.as_any(); diff --git a/crates/polars/tests/it/arrow/scalar/primitive.rs b/crates/polars/tests/it/arrow/scalar/primitive.rs index 954a80147833..d9c2f04b4d37 100644 --- a/crates/polars/tests/it/arrow/scalar/primitive.rs +++ b/crates/polars/tests/it/arrow/scalar/primitive.rs @@ -19,18 +19,18 @@ fn basics() { let a = PrimitiveScalar::from(Some(2i32)); assert_eq!(a.value(), &Some(2i32)); - assert_eq!(a.data_type(), &ArrowDataType::Int32); + assert_eq!(a.dtype(), &ArrowDataType::Int32); let a = a.to(ArrowDataType::Date32); - assert_eq!(a.data_type(), &ArrowDataType::Date32); + assert_eq!(a.dtype(), &ArrowDataType::Date32); let a = PrimitiveScalar::::from(None); - assert_eq!(a.data_type(), &ArrowDataType::Int32); + assert_eq!(a.dtype(), &ArrowDataType::Int32); assert!(!a.is_valid()); let a = a.to(ArrowDataType::Date32); - assert_eq!(a.data_type(), &ArrowDataType::Date32); + assert_eq!(a.dtype(), &ArrowDataType::Date32); let _: &dyn std::any::Any = a.as_any(); } diff --git a/crates/polars/tests/it/arrow/scalar/struct_.rs b/crates/polars/tests/it/arrow/scalar/struct_.rs index 23461bb26568..1b4de73ef25c 100644 --- a/crates/polars/tests/it/arrow/scalar/struct_.rs +++ b/crates/polars/tests/it/arrow/scalar/struct_.rs @@ -4,7 +4,7 @@ use arrow::scalar::{BooleanScalar, Scalar, StructScalar}; #[allow(clippy::eq_op)] #[test] fn equal() { - let dt = ArrowDataType::Struct(vec![Field::new("a", ArrowDataType::Boolean, true)]); + let dt = ArrowDataType::Struct(vec![Field::new("a".into(), ArrowDataType::Boolean, true)]); let a = StructScalar::new( dt.clone(), Some(vec![ @@ -27,14 +27,14 @@ fn equal() { #[test] fn basics() { - let dt = ArrowDataType::Struct(vec![Field::new("a", ArrowDataType::Boolean, true)]); + let dt = ArrowDataType::Struct(vec![Field::new("a".into(), ArrowDataType::Boolean, true)]); let values = vec![Box::new(BooleanScalar::from(Some(true))) as Box]; let a = StructScalar::new(dt.clone(), Some(values.clone())); assert_eq!(a.values(), &values); - assert_eq!(a.data_type(), &dt); + assert_eq!(a.dtype(), &dt); assert!(a.is_valid()); let _: &dyn std::any::Any = a.as_any(); diff --git a/crates/polars/tests/it/arrow/scalar/utf8.rs b/crates/polars/tests/it/arrow/scalar/utf8.rs index bd7c6449d89c..249922bcd60c 100644 --- a/crates/polars/tests/it/arrow/scalar/utf8.rs +++ b/crates/polars/tests/it/arrow/scalar/utf8.rs @@ -19,12 +19,12 @@ fn basics() { let a = 
Utf8Scalar::::from(Some("a")); assert_eq!(a.value(), Some("a")); - assert_eq!(a.data_type(), &ArrowDataType::Utf8); + assert_eq!(a.dtype(), &ArrowDataType::Utf8); assert!(a.is_valid()); let a = Utf8Scalar::::from(None::<&str>); - assert_eq!(a.data_type(), &ArrowDataType::LargeUtf8); + assert_eq!(a.dtype(), &ArrowDataType::LargeUtf8); assert!(!a.is_valid()); let _: &dyn std::any::Any = a.as_any(); diff --git a/crates/polars/tests/it/chunks/parquet.rs b/crates/polars/tests/it/chunks/parquet.rs index 26c37566845a..384382fdd5f9 100644 --- a/crates/polars/tests/it/chunks/parquet.rs +++ b/crates/polars/tests/it/chunks/parquet.rs @@ -11,7 +11,7 @@ fn test_cast_join_14872() { let mut df2 = df![ "ints" => [0, 1], - "strings" => vec![Series::new("", ["a"]); 2], + "strings" => vec![Series::new("".into(), ["a"]); 2], ] .unwrap(); @@ -30,7 +30,7 @@ fn test_cast_join_14872() { let expected = df![ "ints" => [1], - "strings" => vec![Series::new("", ["a"]); 1], + "strings" => vec![Series::new("".into(), ["a"]); 1], ] .unwrap(); diff --git a/crates/polars/tests/it/core/date_like.rs b/crates/polars/tests/it/core/date_like.rs index df91aa512afc..7777d3fd1eb0 100644 --- a/crates/polars/tests/it/core/date_like.rs +++ b/crates/polars/tests/it/core/date_like.rs @@ -4,9 +4,9 @@ use super::*; #[cfg(feature = "dtype-datetime")] #[cfg_attr(miri, ignore)] fn test_datelike_join() -> PolarsResult<()> { - let s = Series::new("foo", &[1, 2, 3]); + let s = Series::new("foo".into(), &[1, 2, 3]); let mut s1 = s.cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?; - s1.rename("bar"); + s1.rename("bar".into()); let df = DataFrame::new(vec![s, s1])?; @@ -33,7 +33,7 @@ fn test_datelike_join() -> PolarsResult<()> { #[test] #[cfg(all(feature = "dtype-datetime", feature = "dtype-duration"))] fn test_datelike_methods() -> PolarsResult<()> { - let s = Series::new("foo", &[1, 2, 3]); + let s = Series::new("foo".into(), &[1, 2, 3]); let s = s.cast(&DataType::Datetime(TimeUnit::Nanoseconds, None))?; let out = s.subtract(&s)?; @@ -52,7 +52,7 @@ fn test_datelike_methods() -> PolarsResult<()> { #[test] #[cfg(all(feature = "dtype-datetime", feature = "dtype-duration"))] fn test_arithmetic_dispatch() { - let s = Int64Chunked::new("", &[1, 2, 3]) + let s = Int64Chunked::new("".into(), &[1, 2, 3]) .into_datetime(TimeUnit::Nanoseconds, None) .into_series(); @@ -113,13 +113,13 @@ fn test_arithmetic_dispatch() { #[test] #[cfg(feature = "dtype-duration")] fn test_duration() -> PolarsResult<()> { - let a = Int64Chunked::new("", &[1, 2, 3]) + let a = Int64Chunked::new("".into(), &[1, 2, 3]) .into_datetime(TimeUnit::Nanoseconds, None) .into_series(); - let b = Int64Chunked::new("", &[2, 3, 4]) + let b = Int64Chunked::new("".into(), &[2, 3, 4]) .into_datetime(TimeUnit::Nanoseconds, None) .into_series(); - let c = Int64Chunked::new("", &[1, 1, 1]) + let c = Int64Chunked::new("".into(), &[1, 1, 1]) .into_duration(TimeUnit::Nanoseconds) .into_series(); assert_eq!( @@ -132,7 +132,7 @@ fn test_duration() -> PolarsResult<()> { ); assert_eq!( b.subtract(&a)?, - Int64Chunked::full("", 1, a.len()) + Int64Chunked::full("".into(), 1, a.len()) .into_duration(TimeUnit::Nanoseconds) .into_series() ); @@ -142,8 +142,12 @@ fn test_duration() -> PolarsResult<()> { #[test] #[cfg(feature = "dtype-duration")] fn test_duration_date_arithmetic() -> PolarsResult<()> { - let date1 = Int32Chunked::new("", &[1, 1, 1]).into_date().into_series(); - let date2 = Int32Chunked::new("", &[2, 3, 4]).into_date().into_series(); + let date1 = Int32Chunked::new("".into(), &[1, 1, 1]) 
+ .into_date() + .into_series(); + let date2 = Int32Chunked::new("".into(), &[2, 3, 4]) + .into_date() + .into_series(); let diff_ms = &date2 - &date1; let diff_ms = diff_ms?; diff --git a/crates/polars/tests/it/core/group_by.rs b/crates/polars/tests/it/core/group_by.rs index f14caad753dd..12241bc5b2eb 100644 --- a/crates/polars/tests/it/core/group_by.rs +++ b/crates/polars/tests/it/core/group_by.rs @@ -5,7 +5,10 @@ use super::*; #[test] fn test_sorted_group_by() -> PolarsResult<()> { // nulls last - let mut s = Series::new("a", &[Some(1), Some(1), Some(1), Some(6), Some(6), None]); + let mut s = Series::new( + "a".into(), + &[Some(1), Some(1), Some(1), Some(6), Some(6), None], + ); s.set_sorted_flag(IsSorted::Ascending); for mt in [true, false] { let out = s.group_tuples(mt, false)?; @@ -14,7 +17,7 @@ fn test_sorted_group_by() -> PolarsResult<()> { // nulls first let mut s = Series::new( - "a", + "a".into(), &[None, None, Some(1), Some(1), Some(1), Some(6), Some(6)], ); s.set_sorted_flag(IsSorted::Ascending); @@ -24,7 +27,10 @@ fn test_sorted_group_by() -> PolarsResult<()> { } // nulls last - let mut s = Series::new("a", &[Some(1), Some(1), Some(1), Some(6), Some(6), None]); + let mut s = Series::new( + "a".into(), + &[Some(1), Some(1), Some(1), Some(6), Some(6), None], + ); s.set_sorted_flag(IsSorted::Ascending); for mt in [true, false] { let out = s.group_tuples(mt, false)?; @@ -33,7 +39,7 @@ fn test_sorted_group_by() -> PolarsResult<()> { // nulls first descending sorted let mut s = Series::new( - "a", + "a".into(), &[ None, None, @@ -53,7 +59,7 @@ fn test_sorted_group_by() -> PolarsResult<()> { // nulls last descending sorted let mut s = Series::new( - "a", + "a".into(), &[ Some(15), Some(15), diff --git a/crates/polars/tests/it/core/joins.rs b/crates/polars/tests/it/core/joins.rs index aa5bfc415697..fe4ec8ba78cb 100644 --- a/crates/polars/tests/it/core/joins.rs +++ b/crates/polars/tests/it/core/joins.rs @@ -39,13 +39,13 @@ fn test_chunked_left_join() -> PolarsResult<()> { } fn create_frames() -> (DataFrame, DataFrame) { - let s0 = Series::new("days", &[0, 1, 2]); - let s1 = Series::new("temp", &[22.1, 19.9, 7.]); - let s2 = Series::new("rain", &[0.2, 0.1, 0.3]); + let s0 = Series::new("days".into(), &[0, 1, 2]); + let s1 = Series::new("temp".into(), &[22.1, 19.9, 7.]); + let s2 = Series::new("rain".into(), &[0.2, 0.1, 0.3]); let temp = DataFrame::new(vec![s0, s1, s2]).unwrap(); - let s0 = Series::new("days", &[1, 2, 3, 1]); - let s1 = Series::new("rain", &[0.1, 0.2, 0.3, 0.4]); + let s0 = Series::new("days".into(), &[1, 2, 3, 1]); + let s1 = Series::new("rain".into(), &[0.1, 0.2, 0.3, 0.4]); let rain = DataFrame::new(vec![s0, s1]).unwrap(); (temp, rain) } @@ -59,10 +59,10 @@ fn test_inner_join() { std::env::set_var("POLARS_MAX_THREADS", format!("{}", i)); let joined = temp.inner_join(&rain, ["days"], ["days"]).unwrap(); - let join_col_days = Series::new("days", &[1, 2, 1]); - let join_col_temp = Series::new("temp", &[19.9, 7., 19.9]); - let join_col_rain = Series::new("rain", &[0.1, 0.3, 0.1]); - let join_col_rain_right = Series::new("rain_right", [0.1, 0.2, 0.4].as_ref()); + let join_col_days = Series::new("days".into(), &[1, 2, 1]); + let join_col_temp = Series::new("temp".into(), &[19.9, 7., 19.9]); + let join_col_rain = Series::new("rain".into(), &[0.1, 0.3, 0.1]); + let join_col_rain_right = Series::new("rain_right".into(), [0.1, 0.2, 0.4].as_ref()); let true_df = DataFrame::new(vec![ join_col_days, join_col_temp, @@ -81,12 +81,12 @@ fn test_inner_join() { fn test_left_join() { 
for i in 1..8 { std::env::set_var("POLARS_MAX_THREADS", format!("{}", i)); - let s0 = Series::new("days", &[0, 1, 2, 3, 4]); - let s1 = Series::new("temp", &[22.1, 19.9, 7., 2., 3.]); + let s0 = Series::new("days".into(), &[0, 1, 2, 3, 4]); + let s1 = Series::new("temp".into(), &[22.1, 19.9, 7., 2., 3.]); let temp = DataFrame::new(vec![s0, s1]).unwrap(); - let s0 = Series::new("days", &[1, 2]); - let s1 = Series::new("rain", &[0.1, 0.2]); + let s0 = Series::new("days".into(), &[1, 2]); + let s1 = Series::new("rain".into(), &[0.1, 0.2]); let rain = DataFrame::new(vec![s0, s1]).unwrap(); let joined = temp.left_join(&rain, ["days"], ["days"]).unwrap(); assert_eq!( @@ -96,12 +96,12 @@ fn test_left_join() { assert_eq!(joined.column("rain").unwrap().null_count(), 3); // test join on string - let s0 = Series::new("days", &["mo", "tue", "wed", "thu", "fri"]); - let s1 = Series::new("temp", &[22.1, 19.9, 7., 2., 3.]); + let s0 = Series::new("days".into(), &["mo", "tue", "wed", "thu", "fri"]); + let s1 = Series::new("temp".into(), &[22.1, 19.9, 7., 2., 3.]); let temp = DataFrame::new(vec![s0, s1]).unwrap(); - let s0 = Series::new("days", &["tue", "wed"]); - let s1 = Series::new("rain", &[0.1, 0.2]); + let s0 = Series::new("days".into(), &["tue", "wed"]); + let s1 = Series::new("rain".into(), &[0.1, 0.2]); let rain = DataFrame::new(vec![s0, s1]).unwrap(); let joined = temp.left_join(&rain, ["days"], ["days"]).unwrap(); assert_eq!( @@ -152,12 +152,16 @@ fn test_full_outer_join() -> PolarsResult<()> { fn test_join_with_nulls() { let dts = &[20, 21, 22, 23, 24, 25, 27, 28]; let vals = &[1.2, 2.4, 4.67, 5.8, 4.4, 3.6, 7.6, 6.5]; - let df = DataFrame::new(vec![Series::new("date", dts), Series::new("val", vals)]).unwrap(); + let df = DataFrame::new(vec![ + Series::new("date".into(), dts), + Series::new("val".into(), vals), + ]) + .unwrap(); let vals2 = &[Some(1.1), None, Some(3.3), None, None]; let df2 = DataFrame::new(vec![ - Series::new("date", &dts[3..]), - Series::new("val2", vals2), + Series::new("date".into(), &dts[3..]), + Series::new("val2".into(), vals2), ]) .unwrap(); @@ -204,7 +208,7 @@ fn test_join_multiple_columns() { .str() .unwrap() + df_a.column("b").unwrap().str().unwrap(); - s.rename("dummy"); + s.rename("dummy".into()); df_a.with_column(s).unwrap(); let mut s = df_b @@ -215,7 +219,7 @@ fn test_join_multiple_columns() { .str() .unwrap() + df_b.column("bar").unwrap().str().unwrap(); - s.rename("dummy"); + s.rename("dummy".into()); df_b.with_column(s).unwrap(); let joined = df_a.left_join(&df_b, ["dummy"], ["dummy"]).unwrap(); @@ -334,14 +338,14 @@ fn test_join_categorical() { fn test_empty_df_join() -> PolarsResult<()> { let empty: Vec = vec![]; let empty_df = DataFrame::new(vec![ - Series::new("key", &empty), - Series::new("eval", &empty), + Series::new("key".into(), &empty), + Series::new("eval".into(), &empty), ]) .unwrap(); let df = DataFrame::new(vec![ - Series::new("key", &["foo"]), - Series::new("aval", &[4]), + Series::new("key".into(), &["foo"]), + Series::new("aval".into(), &[4]), ]) .unwrap(); @@ -357,8 +361,8 @@ fn test_empty_df_join() -> PolarsResult<()> { let empty: Vec = vec![]; let _empty_df = DataFrame::new(vec![ - Series::new("key", &empty), - Series::new("eval", &empty), + Series::new("key".into(), &empty), + Series::new("eval".into(), &empty), ]) .unwrap(); @@ -370,9 +374,9 @@ fn test_empty_df_join() -> PolarsResult<()> { // https://github.com/pola-rs/polars/issues/1824 let empty: Vec = vec![]; let empty_df = DataFrame::new(vec![ - Series::new("key", &empty), - 
Series::new("1val", &empty), - Series::new("2val", &empty), + Series::new("key".into(), &empty), + Series::new("1val".into(), &empty), + Series::new("2val".into(), &empty), ])?; let out = df.left_join(&empty_df, ["key"], ["key"])?; @@ -504,8 +508,8 @@ fn test_multi_joins_with_duplicates() -> PolarsResult<()> { let df_inner_join = df_left .join( &df_right, - &["col1", "join_col2"], - &["join_col1", "col2"], + ["col1", "join_col2"], + ["join_col1", "col2"], JoinType::Inner.into(), ) .unwrap(); @@ -519,8 +523,8 @@ fn test_multi_joins_with_duplicates() -> PolarsResult<()> { let df_left_join = df_left .join( &df_right, - &["col1", "join_col2"], - &["join_col1", "col2"], + ["col1", "join_col2"], + ["join_col1", "col2"], JoinType::Left.into(), ) .unwrap(); @@ -534,8 +538,8 @@ fn test_multi_joins_with_duplicates() -> PolarsResult<()> { let df_full_outer_join = df_left .join( &df_right, - &["col1", "join_col2"], - &["join_col1", "col2"], + ["col1", "join_col2"], + ["join_col1", "col2"], JoinArgs::new(JoinType::Full).with_coalesce(JoinCoalesce::CoalesceColumns), ) .unwrap(); @@ -604,8 +608,8 @@ fn test_4_threads_bit_offset() -> PolarsResult<()> { let mut left_b = (0..n) .map(|i| if i % 2 == 0 { None } else { Some(0) }) .collect::(); - left_a.rename("a"); - left_b.rename("b"); + left_a.rename("a".into()); + left_b.rename("b".into()); let left_df = DataFrame::new(vec![left_a.into_series(), left_b.into_series()])?; let i = 1; @@ -615,8 +619,8 @@ fn test_4_threads_bit_offset() -> PolarsResult<()> { let mut right_b = range .map(|i| if i % 3 == 0 { None } else { Some(1) }) .collect::(); - right_a.rename("a"); - right_b.rename("b"); + right_a.rename("a".into()); + right_b.rename("b".into()); let right_df = DataFrame::new(vec![right_a.into_series(), right_b.into_series()])?; let out = JoinBuilder::new(left_df.lazy()) diff --git a/crates/polars/tests/it/core/list.rs b/crates/polars/tests/it/core/list.rs index d709a40f2be4..f485ccadd482 100644 --- a/crates/polars/tests/it/core/list.rs +++ b/crates/polars/tests/it/core/list.rs @@ -2,7 +2,7 @@ use polars::prelude::*; #[test] fn test_to_list_logical() -> PolarsResult<()> { - let ca = StringChunked::new("a", &["2021-01-01", "2021-01-02", "2021-01-03"]); + let ca = StringChunked::new("a".into(), &["2021-01-01", "2021-01-02", "2021-01-03"]); let out = ca.as_date(None, false)?.into_series(); let out = out.implode().unwrap(); assert_eq!(out.len(), 1); diff --git a/crates/polars/tests/it/core/ops/take.rs b/crates/polars/tests/it/core/ops/take.rs index 26c1bb651865..373c644da066 100644 --- a/crates/polars/tests/it/core/ops/take.rs +++ b/crates/polars/tests/it/core/ops/take.rs @@ -3,12 +3,12 @@ use super::*; #[test] fn test_list_gather_nulls_and_empty() { let a: &[i32] = &[]; - let a = Series::new("", a); - let b = Series::new("", &[None, Some(a.clone())]); + let a = Series::new("".into(), a); + let b = Series::new("".into(), &[None, Some(a.clone())]); let indices = [Some(0 as IdxSize), Some(1), None] .into_iter() - .collect_ca(""); + .collect_ca("".into()); let out = b.take(&indices).unwrap(); - let expected = Series::new("", &[None, Some(a), None]); + let expected = Series::new("".into(), &[None, Some(a), None]); assert!(out.equals_missing(&expected)) } diff --git a/crates/polars/tests/it/core/pivot.rs b/crates/polars/tests/it/core/pivot.rs index b0e1b13ca9f4..85cf69ec1494 100644 --- a/crates/polars/tests/it/core/pivot.rs +++ b/crates/polars/tests/it/core/pivot.rs @@ -56,9 +56,9 @@ fn test_pivot_date_() -> PolarsResult<()> { #[test] fn test_pivot_old() { - let s0 = 
Series::new("index", ["A", "A", "B", "B", "C"].as_ref()); - let s2 = Series::new("columns", ["k", "l", "m", "m", "l"].as_ref()); - let s1 = Series::new("values", [1, 2, 2, 4, 2].as_ref()); + let s0 = Series::new("index".into(), ["A", "A", "B", "B", "C"].as_ref()); + let s2 = Series::new("columns".into(), ["k", "l", "m", "m", "l"].as_ref()); + let s1 = Series::new("values".into(), [1, 2, 2, 4, 2].as_ref()); let df = DataFrame::new(vec![s0, s1, s2]).unwrap(); let pvt = pivot( diff --git a/crates/polars/tests/it/core/rolling_window.rs b/crates/polars/tests/it/core/rolling_window.rs index b823bf7d8736..a58280e09345 100644 --- a/crates/polars/tests/it/core/rolling_window.rs +++ b/crates/polars/tests/it/core/rolling_window.rs @@ -2,7 +2,7 @@ use super::*; #[test] fn test_rolling() { - let s = Int32Chunked::new("foo", &[1, 2, 3, 2, 1]).into_series(); + let s = Int32Chunked::new("foo".into(), &[1, 2, 3, 2, 1]).into_series(); let a = s .rolling_sum(RollingOptionsFixedWindow { window_size: 2, @@ -57,7 +57,7 @@ fn test_rolling() { #[test] fn test_rolling_min_periods() { - let s = Int32Chunked::new("foo", &[1, 2, 3, 2, 1]).into_series(); + let s = Int32Chunked::new("foo".into(), &[1, 2, 3, 2, 1]).into_series(); let a = s .rolling_max(RollingOptionsFixedWindow { window_size: 2, @@ -72,7 +72,7 @@ fn test_rolling_min_periods() { #[test] fn test_rolling_mean() { let s = Float64Chunked::new( - "foo", + "foo".into(), &[ Some(0.0), Some(1.0), @@ -141,7 +141,7 @@ fn test_rolling_mean() { ); // integers - let ca = Int32Chunked::from_slice("", &[1, 8, 6, 2, 16, 10]); + let ca = Int32Chunked::from_slice("".into(), &[1, 8, 6, 2, 16, 10]); let out = ca .into_series() .rolling_mean(RollingOptionsFixedWindow { @@ -163,7 +163,7 @@ fn test_rolling_mean() { #[test] fn test_rolling_map() { let ca = Float64Chunked::new( - "foo", + "foo".into(), &[ Some(0.0), Some(1.0), @@ -177,7 +177,7 @@ fn test_rolling_map() { let out = ca .rolling_map( - &|s| s.sum_reduce().unwrap().into_series(s.name()), + &|s| s.sum_reduce().unwrap().into_series(s.name().clone()), RollingOptionsFixedWindow { window_size: 3, min_periods: 3, @@ -197,7 +197,7 @@ fn test_rolling_map() { #[test] fn test_rolling_var() { let s = Float64Chunked::new( - "foo", + "foo".into(), &[ Some(0.0), Some(1.0), @@ -237,7 +237,7 @@ fn test_rolling_var() { &[None, None, Some(1), None, None, None, None,] ); - let s = Float64Chunked::from_slice("", &[0.0, 2.0, 8.0, 3.0, 12.0, 1.0]).into_series(); + let s = Float64Chunked::from_slice("".into(), &[0.0, 2.0, 8.0, 3.0, 12.0, 1.0]).into_series(); let out = s .rolling_var(options) .unwrap() diff --git a/crates/polars/tests/it/core/series.rs b/crates/polars/tests/it/core/series.rs index 017609898b51..3d740ad5d940 100644 --- a/crates/polars/tests/it/core/series.rs +++ b/crates/polars/tests/it/core/series.rs @@ -3,19 +3,19 @@ use polars::series::*; #[test] fn test_series_arithmetic() -> PolarsResult<()> { - let a = &Series::new("a", &[1, 100, 6, 40]); - let b = &Series::new("b", &[-1, 2, 3, 4]); - assert_eq!((a + b)?, Series::new("a", &[0, 102, 9, 44])); - assert_eq!((a - b)?, Series::new("a", &[2, 98, 3, 36])); - assert_eq!((a * b)?, Series::new("a", &[-1, 200, 18, 160])); - assert_eq!((a / b)?, Series::new("a", &[-1, 50, 2, 10])); + let a = &Series::new("a".into(), &[1, 100, 6, 40]); + let b = &Series::new("b".into(), &[-1, 2, 3, 4]); + assert_eq!((a + b)?, Series::new("a".into(), &[0, 102, 9, 44])); + assert_eq!((a - b)?, Series::new("a".into(), &[2, 98, 3, 36])); + assert_eq!((a * b)?, Series::new("a".into(), &[-1, 200, 18, 
160])); + assert_eq!((a / b)?, Series::new("a".into(), &[-1, 50, 2, 10])); Ok(()) } #[test] fn test_min_max_sorted_asc() { - let a = &mut Series::new("a", &[1, 2, 3, 4]); + let a = &mut Series::new("a".into(), &[1, 2, 3, 4]); a.set_sorted_flag(IsSorted::Ascending); assert_eq!(a.max().unwrap(), Some(4)); assert_eq!(a.min().unwrap(), Some(1)); @@ -23,7 +23,7 @@ fn test_min_max_sorted_asc() { #[test] fn test_min_max_sorted_desc() { - let a = &mut Series::new("a", &[4, 3, 2, 1]); + let a = &mut Series::new("a".into(), &[4, 3, 2, 1]); a.set_sorted_flag(IsSorted::Descending); assert_eq!(a.max().unwrap(), Some(4)); assert_eq!(a.min().unwrap(), Some(1)); @@ -31,7 +31,13 @@ fn test_min_max_sorted_desc() { #[test] fn test_construct_list_of_null_series() { - let s = Series::new("a", [Series::new_null("a1", 1), Series::new_null("a1", 1)]); + let s = Series::new( + "a".into(), + [ + Series::new_null("a1".into(), 1), + Series::new_null("a1".into(), 1), + ], + ); assert_eq!(s.null_count(), 0); assert_eq!(s.field().name(), "a"); } diff --git a/crates/polars/tests/it/io/avro/read.rs b/crates/polars/tests/it/io/avro/read.rs index 2482fb6103c7..dac9adbfc9d0 100644 --- a/crates/polars/tests/it/io/avro/read.rs +++ b/crates/polars/tests/it/io/avro/read.rs @@ -54,28 +54,32 @@ pub(super) fn schema() -> (AvroSchema, ArrowSchema) { } "#; - let schema = ArrowSchema::from(vec![ - Field::new("a", ArrowDataType::Int64, false), - Field::new("b", ArrowDataType::Utf8, false), - Field::new("c", ArrowDataType::Int32, false), - Field::new("date", ArrowDataType::Date32, false), - Field::new("d", ArrowDataType::Binary, false), - Field::new("e", ArrowDataType::Float64, false), - Field::new("f", ArrowDataType::Boolean, false), - Field::new("g", ArrowDataType::Utf8, true), + let schema = ArrowSchema::from_iter([ + Field::new("a".into(), ArrowDataType::Int64, false), + Field::new("b".into(), ArrowDataType::Utf8, false), + Field::new("c".into(), ArrowDataType::Int32, false), + Field::new("date".into(), ArrowDataType::Date32, false), + Field::new("d".into(), ArrowDataType::Binary, false), + Field::new("e".into(), ArrowDataType::Float64, false), + Field::new("f".into(), ArrowDataType::Boolean, false), + Field::new("g".into(), ArrowDataType::Utf8, true), Field::new( - "h", - ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, true))), + "h".into(), + ArrowDataType::List(Box::new(Field::new( + "item".into(), + ArrowDataType::Int32, + true, + ))), false, ), Field::new( - "i", - ArrowDataType::Struct(vec![Field::new("e", ArrowDataType::Float64, false)]), + "i".into(), + ArrowDataType::Struct(vec![Field::new("e".into(), ArrowDataType::Float64, false)]), false, ), Field::new( - "nullable_struct", - ArrowDataType::Struct(vec![Field::new("e", ArrowDataType::Float64, false)]), + "nullable_struct".into(), + ArrowDataType::Struct(vec![Field::new("e".into(), ArrowDataType::Float64, false)]), true, ), ]); @@ -105,13 +109,13 @@ pub(super) fn data() -> RecordBatchT> { Utf8Array::::from([Some("foo"), None]).boxed(), array.into_box(), StructArray::new( - ArrowDataType::Struct(vec![Field::new("e", ArrowDataType::Float64, false)]), + ArrowDataType::Struct(vec![Field::new("e".into(), ArrowDataType::Float64, false)]), vec![PrimitiveArray::::from_slice([1.0, 2.0]).boxed()], None, ) .boxed(), StructArray::new( - ArrowDataType::Struct(vec![Field::new("e", ArrowDataType::Float64, false)]), + ArrowDataType::Struct(vec![Field::new("e".into(), ArrowDataType::Float64, false)]), vec![PrimitiveArray::::from_slice([1.0, 0.0]).boxed()], Some([true, 
false].into()), ) @@ -199,16 +203,14 @@ pub(super) fn read_avro( let metadata = read_metadata(file)?; let schema = read::infer_schema(&metadata.record)?; - let mut reader = read::Reader::new(file, metadata, schema.fields.clone(), projection.clone()); + let mut reader = read::Reader::new(file, metadata, schema.clone(), projection.clone()); let schema = if let Some(projection) = projection { - let fields = schema - .fields - .into_iter() + schema + .into_iter_values() .zip(projection.iter()) .filter_map(|x| if *x.1 { Some(x.0) } else { None }) - .collect::>(); - ArrowSchema::from(fields) + .collect() } else { schema }; @@ -250,8 +252,8 @@ fn test_projected() -> PolarsResult<()> { let avro = write_avro(Codec::Null).unwrap(); - for i in 0..expected_schema.fields.len() { - let mut projection = vec![false; expected_schema.fields.len()]; + for i in 0..expected_schema.len() { + let mut projection = vec![false; expected_schema.len()]; projection[i] = true; let expected = expected @@ -263,14 +265,12 @@ fn test_projected() -> PolarsResult<()> { .collect(); let expected = RecordBatchT::new(expected); - let expected_fields = expected_schema + let expected_schema = expected_schema .clone() - .fields - .into_iter() + .into_iter_values() .zip(projection.iter()) .filter_map(|x| if *x.1 { Some(x.0) } else { None }) - .collect::>(); - let expected_schema = ArrowSchema::from(expected_fields); + .collect(); let (result, schema) = read_avro(&avro, Some(projection))?; @@ -297,9 +297,13 @@ fn schema_list() -> (AvroSchema, ArrowSchema) { } "#; - let schema = ArrowSchema::from(vec![Field::new( - "h", - ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, false))), + let schema = ArrowSchema::from_iter([Field::new( + "h".into(), + ArrowDataType::List(Box::new(Field::new( + "item".into(), + ArrowDataType::Int32, + false, + ))), false, )]); @@ -311,7 +315,11 @@ pub(super) fn data_list() -> RecordBatchT> { let mut array = MutableListArray::>::new_from( Default::default(), - ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, false))), + ArrowDataType::List(Box::new(Field::new( + "item".into(), + ArrowDataType::Int32, + false, + ))), 0, ); array.try_extend(data).unwrap(); diff --git a/crates/polars/tests/it/io/avro/write.rs b/crates/polars/tests/it/io/avro/write.rs index dade870e96c6..43011eb7a2bf 100644 --- a/crates/polars/tests/it/io/avro/write.rs +++ b/crates/polars/tests/it/io/avro/write.rs @@ -15,39 +15,55 @@ use polars_error::PolarsResult; use super::read::read_avro; pub(super) fn schema() -> ArrowSchema { - ArrowSchema::from(vec![ - Field::new("int64", ArrowDataType::Int64, false), - Field::new("int64 nullable", ArrowDataType::Int64, true), - Field::new("utf8", ArrowDataType::Utf8, false), - Field::new("utf8 nullable", ArrowDataType::Utf8, true), - Field::new("int32", ArrowDataType::Int32, false), - Field::new("int32 nullable", ArrowDataType::Int32, true), - Field::new("date", ArrowDataType::Date32, false), - Field::new("date nullable", ArrowDataType::Date32, true), - Field::new("binary", ArrowDataType::Binary, false), - Field::new("binary nullable", ArrowDataType::Binary, true), - Field::new("float32", ArrowDataType::Float32, false), - Field::new("float32 nullable", ArrowDataType::Float32, true), - Field::new("float64", ArrowDataType::Float64, false), - Field::new("float64 nullable", ArrowDataType::Float64, true), - Field::new("boolean", ArrowDataType::Boolean, false), - Field::new("boolean nullable", ArrowDataType::Boolean, true), + ArrowSchema::from_iter([ + 
Field::new("int64".into(), ArrowDataType::Int64, false), + Field::new("int64 nullable".into(), ArrowDataType::Int64, true), + Field::new("utf8".into(), ArrowDataType::Utf8, false), + Field::new("utf8 nullable".into(), ArrowDataType::Utf8, true), + Field::new("int32".into(), ArrowDataType::Int32, false), + Field::new("int32 nullable".into(), ArrowDataType::Int32, true), + Field::new("date".into(), ArrowDataType::Date32, false), + Field::new("date nullable".into(), ArrowDataType::Date32, true), + Field::new("binary".into(), ArrowDataType::Binary, false), + Field::new("binary nullable".into(), ArrowDataType::Binary, true), + Field::new("float32".into(), ArrowDataType::Float32, false), + Field::new("float32 nullable".into(), ArrowDataType::Float32, true), + Field::new("float64".into(), ArrowDataType::Float64, false), + Field::new("float64 nullable".into(), ArrowDataType::Float64, true), + Field::new("boolean".into(), ArrowDataType::Boolean, false), + Field::new("boolean nullable".into(), ArrowDataType::Boolean, true), Field::new( - "list", - ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, true))), + "list".into(), + ArrowDataType::List(Box::new(Field::new( + "item".into(), + ArrowDataType::Int32, + true, + ))), false, ), Field::new( - "list nullable", - ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, true))), + "list nullable".into(), + ArrowDataType::List(Box::new(Field::new( + "item".into(), + ArrowDataType::Int32, + true, + ))), true, ), ]) } pub(super) fn data() -> RecordBatchT> { - let list_dt = ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, true))); - let list_dt1 = ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, true))); + let list_dt = ArrowDataType::List(Box::new(Field::new( + "item".into(), + ArrowDataType::Int32, + true, + ))); + let list_dt1 = ArrowDataType::List(Box::new(Field::new( + "item".into(), + ArrowDataType::Int32, + true, + ))); let columns = vec![ Box::new(Int64Array::from_slice([27, 47])) as Box, @@ -162,11 +178,15 @@ fn deflate() -> PolarsResult<()> { } fn large_format_schema() -> ArrowSchema { - ArrowSchema::from(vec![ - Field::new("large_utf8", ArrowDataType::LargeUtf8, false), - Field::new("large_utf8_nullable", ArrowDataType::LargeUtf8, true), - Field::new("large_binary", ArrowDataType::LargeBinary, false), - Field::new("large_binary_nullable", ArrowDataType::LargeBinary, true), + ArrowSchema::from_iter([ + Field::new("large_utf8".into(), ArrowDataType::LargeUtf8, false), + Field::new("large_utf8_nullable".into(), ArrowDataType::LargeUtf8, true), + Field::new("large_binary".into(), ArrowDataType::LargeBinary, false), + Field::new( + "large_binary_nullable".into(), + ArrowDataType::LargeBinary, + true, + ), ]) } @@ -181,11 +201,11 @@ fn large_format_data() -> RecordBatchT> { } fn large_format_expected_schema() -> ArrowSchema { - ArrowSchema::from(vec![ - Field::new("large_utf8", ArrowDataType::Utf8, false), - Field::new("large_utf8_nullable", ArrowDataType::Utf8, true), - Field::new("large_binary", ArrowDataType::Binary, false), - Field::new("large_binary_nullable", ArrowDataType::Binary, true), + ArrowSchema::from_iter([ + Field::new("large_utf8".into(), ArrowDataType::Utf8, false), + Field::new("large_utf8_nullable".into(), ArrowDataType::Utf8, true), + Field::new("large_binary".into(), ArrowDataType::Binary, false), + Field::new("large_binary_nullable".into(), ArrowDataType::Binary, true), ]) } @@ -219,20 +239,20 @@ fn check_large_format() -> PolarsResult<()> { } fn 
struct_schema() -> ArrowSchema { - ArrowSchema::from(vec![ + ArrowSchema::from_iter([ Field::new( - "struct", + "struct".into(), ArrowDataType::Struct(vec![ - Field::new("item1", ArrowDataType::Int32, false), - Field::new("item2", ArrowDataType::Int32, true), + Field::new("item1".into(), ArrowDataType::Int32, false), + Field::new("item2".into(), ArrowDataType::Int32, true), ]), false, ), Field::new( - "struct nullable", + "struct nullable".into(), ArrowDataType::Struct(vec![ - Field::new("item1", ArrowDataType::Int32, false), - Field::new("item2", ArrowDataType::Int32, true), + Field::new("item1".into(), ArrowDataType::Int32, false), + Field::new("item2".into(), ArrowDataType::Int32, true), ]), true, ), @@ -241,8 +261,8 @@ fn struct_schema() -> ArrowSchema { fn struct_data() -> RecordBatchT> { let struct_dt = ArrowDataType::Struct(vec![ - Field::new("item1", ArrowDataType::Int32, false), - Field::new("item2", ArrowDataType::Int32, true), + Field::new("item1".into(), ArrowDataType::Int32, false), + Field::new("item2".into(), ArrowDataType::Int32, true), ]); RecordBatchT::new(vec![ diff --git a/crates/polars/tests/it/io/csv.rs b/crates/polars/tests/it/io/csv.rs index 1b78969093e2..992754436c60 100644 --- a/crates/polars/tests/it/io/csv.rs +++ b/crates/polars/tests/it/io/csv.rs @@ -44,10 +44,16 @@ fn write_csv() { fn write_dates() { use polars_core::export::chrono; - let s0 = Series::new("date", [chrono::NaiveDate::from_yo_opt(2024, 33), None]); - let s1 = Series::new("time", [None, chrono::NaiveTime::from_hms_opt(19, 50, 0)]); + let s0 = Series::new( + "date".into(), + [chrono::NaiveDate::from_yo_opt(2024, 33), None], + ); + let s1 = Series::new( + "time".into(), + [None, chrono::NaiveTime::from_hms_opt(19, 50, 0)], + ); let s2 = Series::new( - "datetime", + "datetime".into(), [ Some(chrono::NaiveDateTime::new( chrono::NaiveDate::from_ymd_opt(2000, 12, 1).unwrap(), @@ -112,7 +118,7 @@ fn write_dates() { let with_timezone = polars_ops::chunked_array::replace_time_zone( s2.slice(0, 1).datetime().unwrap(), Some("America/New_York"), - &StringChunked::new("", ["raise"]), + &StringChunked::new("".into(), ["raise"]), NonExistent::Raise, ) .unwrap() @@ -214,7 +220,7 @@ fn test_parser() -> PolarsResult<()> { assert_eq!(col.get(0)?, AnyValue::String("Setosa")); assert_eq!(col.get(2)?, AnyValue::String("Setosa")); - assert_eq!("sepal_length", df.get_columns()[0].name()); + assert_eq!("sepal_length", df.get_columns()[0].name().as_str()); assert_eq!(1, df.column("sepal_length").unwrap().chunks().len()); assert_eq!(df.height(), 7); @@ -229,7 +235,7 @@ fn test_parser() -> PolarsResult<()> { .finish() .unwrap(); - assert_eq!("head_1", df.get_columns()[0].name()); + assert_eq!("head_1", df.get_columns()[0].name().as_str()); assert_eq!(df.shape(), (3, 2)); // test windows line ending with 1 byte char column and no line endings for last line. 
@@ -243,7 +249,7 @@ fn test_parser() -> PolarsResult<()> { .finish() .unwrap(); - assert_eq!("head_1", df.get_columns()[0].name()); + assert_eq!("head_1", df.get_columns()[0].name().as_str()); assert_eq!(df.shape(), (3, 1)); Ok(()) } @@ -303,15 +309,15 @@ fn test_missing_data() { assert!(df .column("column_1") .unwrap() - .equals(&Series::new("column_1", &[1_i64, 1]))); + .equals(&Series::new("column_1".into(), &[1_i64, 1]))); assert!(df .column("column_2") .unwrap() - .equals_missing(&Series::new("column_2", &[Some(2_i64), None]))); + .equals_missing(&Series::new("column_2".into(), &[Some(2_i64), None]))); assert!(df .column("column_3") .unwrap() - .equals(&Series::new("column_3", &[3_i64, 3]))); + .equals(&Series::new("column_3".into(), &[3_i64, 3]))); } #[test] @@ -326,7 +332,7 @@ fn test_escape_comma() { assert!(df .column("column_3") .unwrap() - .equals(&Series::new("column_3", &[11_i64, 12]))); + .equals(&Series::new("column_3".into(), &[11_i64, 12]))); } #[test] @@ -339,7 +345,7 @@ fn test_escape_double_quotes() { let df = CsvReader::new(file).finish().unwrap(); assert_eq!(df.shape(), (2, 3)); assert!(df.column("column_2").unwrap().equals(&Series::new( - "column_2", + "column_2".into(), &[ r#"with "double quotes" US"#, r#"with "double quotes followed", by comma"# @@ -387,7 +393,7 @@ hello,","," ",world,"!" .finish() .unwrap(); - for (col, val) in &[ + for (col, val) in [ ("column_1", "hello"), ("column_2", ","), ("column_3", " "), @@ -397,7 +403,7 @@ hello,","," ",world,"!" assert!(df .column(col) .unwrap() - .equals(&Series::new(col, &[&**val; 4]))); + .equals(&Series::new(col.into(), &[val; 4]))); } } @@ -420,7 +426,7 @@ versions of Lorem Ipsum.",11 .unwrap(); assert!(df.column("column_2").unwrap().equals(&Series::new( - "column_2", + "column_2".into(), &[ r#"Lorem Ipsum is simply dummy text of the printing and typesetting industry. 
Lorem Ipsum has been the industry's standard dummy text ever since th @@ -508,13 +514,13 @@ fn test_quoted_numeric() { #[test] fn test_empty_bytes_to_dataframe() { - let fields = vec![Field::new("test_field", DataType::String)]; + let fields = vec![Field::new("test_field".into(), DataType::String)]; let schema = Schema::from_iter(fields); let file = Cursor::new(vec![]); let result = CsvReadOptions::default() .with_has_header(false) - .with_columns(Some(schema.iter_names().map(|s| s.to_string()).collect())) + .with_columns(Some(schema.iter_names_cloned().collect())) .with_schema(Some(Arc::new(schema))) .into_reader_with_file_handle(file) .finish(); @@ -548,9 +554,9 @@ fn test_missing_value() { let df = CsvReadOptions::default() .with_has_header(true) .with_schema(Some(Arc::new(Schema::from_iter([ - Field::new("foo", DataType::UInt32), - Field::new("bar", DataType::UInt32), - Field::new("ham", DataType::UInt32), + Field::new("foo".into(), DataType::UInt32), + Field::new("bar".into(), DataType::UInt32), + Field::new("ham".into(), DataType::UInt32), ])))) .into_reader_with_file_handle(file) .finish() @@ -571,7 +577,7 @@ AUDCAD,1616455921,0.96212,0.95666,1 let df = CsvReadOptions::default() .with_has_header(true) .with_schema_overwrite(Some(Arc::new(Schema::from_iter([Field::new( - "b", + "b".into(), DataType::Datetime(TimeUnit::Nanoseconds, None), )])))) .with_ignore_errors(true) @@ -730,8 +736,7 @@ null-value,b,bar let file = Cursor::new(csv); let df = CsvReadOptions::default() .map_parse_options(|parse_options| { - parse_options - .with_null_values(Some(NullValues::AllColumnsSingle("null-value".to_string()))) + parse_options.with_null_values(Some(NullValues::AllColumnsSingle("null-value".into()))) }) .into_reader_with_file_handle(file) .finish()?; diff --git a/crates/polars/tests/it/io/ipc.rs b/crates/polars/tests/it/io/ipc.rs index 6b5e2a83ba41..8a5602c86051 100644 --- a/crates/polars/tests/it/io/ipc.rs +++ b/crates/polars/tests/it/io/ipc.rs @@ -24,8 +24,8 @@ fn test_ipc_compression_variadic_buffers() { #[cfg(test)] pub(crate) fn create_df() -> DataFrame { - let s0 = Series::new("days", [0, 1, 2, 3, 4].as_ref()); - let s1 = Series::new("temp", [22.1, 19.9, 7., 2., 3.].as_ref()); + let s0 = Series::new("days".into(), [0, 1, 2, 3, 4].as_ref()); + let s1 = Series::new("temp".into(), [22.1, 19.9, 7., 2., 3.].as_ref()); DataFrame::new(vec![s0, s1]).unwrap() } @@ -140,7 +140,7 @@ fn test_write_with_compression() { #[test] fn write_and_read_ipc_empty_series() { let mut buf: Cursor> = Cursor::new(Vec::new()); - let chunked_array = Float64Chunked::new("empty", &[0_f64; 0]); + let chunked_array = Float64Chunked::new("empty".into(), &[0_f64; 0]); let mut df = DataFrame::new(vec![chunked_array.into_series()]).unwrap(); IpcWriter::new(&mut buf) .finish(&mut df) diff --git a/crates/polars/tests/it/io/ipc_stream.rs b/crates/polars/tests/it/io/ipc_stream.rs index 18d67990cb53..d12082d0dd71 100644 --- a/crates/polars/tests/it/io/ipc_stream.rs +++ b/crates/polars/tests/it/io/ipc_stream.rs @@ -145,7 +145,10 @@ mod test { #[test] fn write_and_read_ipc_stream_empty_series() { fn df() -> DataFrame { - DataFrame::new(vec![Float64Chunked::new("empty", &[0_f64; 0]).into_series()]).unwrap() + DataFrame::new(vec![ + Float64Chunked::new("empty".into(), &[0_f64; 0]).into_series() + ]) + .unwrap() } let reader = create_ipc_stream(df()); diff --git a/crates/polars/tests/it/io/json.rs b/crates/polars/tests/it/io/json.rs index faf17d71d07e..9095d4299bdd 100644 --- a/crates/polars/tests/it/io/json.rs +++ 
b/crates/polars/tests/it/io/json.rs @@ -25,8 +25,8 @@ fn read_json() { .with_batch_size(NonZeroUsize::new(3).unwrap()) .finish() .unwrap(); - assert_eq!("a", df.get_columns()[0].name()); - assert_eq!("d", df.get_columns()[3].name()); + assert_eq!("a", df.get_columns()[0].name().as_str()); + assert_eq!("d", df.get_columns()[3].name().as_str()); assert_eq!((12, 4), df.shape()); } #[test] @@ -53,8 +53,8 @@ fn read_json_with_whitespace() { .with_batch_size(NonZeroUsize::new(3).unwrap()) .finish() .unwrap(); - assert_eq!("a", df.get_columns()[0].name()); - assert_eq!("d", df.get_columns()[3].name()); + assert_eq!("a", df.get_columns()[0].name().as_str()); + assert_eq!("d", df.get_columns()[3].name().as_str()); assert_eq!((12, 4), df.shape()); } #[test] @@ -76,12 +76,12 @@ fn read_json_with_escapes() { .infer_schema_len(NonZeroUsize::new(6)) .finish() .unwrap(); - assert_eq!("id", df.get_columns()[0].name()); + assert_eq!("id", df.get_columns()[0].name().as_str()); assert_eq!( AnyValue::String("\""), df.column("text").unwrap().get(0).unwrap() ); - assert_eq!("text", df.get_columns()[1].name()); + assert_eq!("text", df.get_columns()[1].name().as_str()); assert_eq!((10, 3), df.shape()); } @@ -107,8 +107,8 @@ fn read_unordered_json() { .with_batch_size(NonZeroUsize::new(3).unwrap()) .finish() .unwrap(); - assert_eq!("a", df.get_columns()[0].name()); - assert_eq!("d", df.get_columns()[3].name()); + assert_eq!("a", df.get_columns()[0].name().as_str()); + assert_eq!("d", df.get_columns()[3].name().as_str()); assert_eq!((12, 4), df.shape()); } @@ -141,11 +141,17 @@ fn test_read_ndjson_iss_5875() { let df = JsonLineReader::new(cursor).finish(); assert!(df.is_ok()); - let field_int_inner = Field::new("int_inner", DataType::List(Box::new(DataType::Int64))); - let field_float_inner = Field::new("float_inner", DataType::Float64); - let field_str_inner = Field::new("str_inner", DataType::List(Box::new(DataType::String))); + let field_int_inner = Field::new( + "int_inner".into(), + DataType::List(Box::new(DataType::Int64)), + ); + let field_float_inner = Field::new("float_inner".into(), DataType::Float64); + let field_str_inner = Field::new( + "str_inner".into(), + DataType::List(Box::new(DataType::String)), + ); - let mut schema = Schema::new(); + let mut schema = Schema::default(); schema.with_column( "struct".into(), DataType::Struct(vec![field_int_inner, field_float_inner, field_str_inner]), diff --git a/crates/polars/tests/it/io/mod.rs b/crates/polars/tests/it/io/mod.rs index 4835171721c9..2fd9aab899d1 100644 --- a/crates/polars/tests/it/io/mod.rs +++ b/crates/polars/tests/it/io/mod.rs @@ -17,7 +17,7 @@ mod ipc_stream; use polars::prelude::*; pub(crate) fn create_df() -> DataFrame { - let s0 = Series::new("days", [0, 1, 2, 3, 4].as_ref()); - let s1 = Series::new("temp", [22.1, 19.9, 7., 2., 3.].as_ref()); + let s0 = Series::new("days".into(), [0, 1, 2, 3, 4].as_ref()); + let s1 = Series::new("temp".into(), [22.1, 19.9, 7., 2., 3.].as_ref()); DataFrame::new(vec![s0, s1]).unwrap() } diff --git a/crates/polars/tests/it/io/parquet/arrow/mod.rs b/crates/polars/tests/it/io/parquet/arrow/mod.rs index 477b9a1321df..a54b5fcacb1c 100644 --- a/crates/polars/tests/it/io/parquet/arrow/mod.rs +++ b/crates/polars/tests/it/io/parquet/arrow/mod.rs @@ -15,7 +15,7 @@ use polars_parquet::read as p_read; use polars_parquet::read::statistics::*; use polars_parquet::write::*; -type ArrayStats = (Box, Statistics); +use super::read::file::FileReader; fn new_struct( arrays: Vec>, @@ -25,38 +25,22 @@ fn new_struct( let fields = 
names .into_iter() .zip(arrays.iter()) - .map(|(n, a)| Field::new(n, a.data_type().clone(), true)) + .map(|(n, a)| Field::new(n.into(), a.dtype().clone(), true)) .collect(); StructArray::new(ArrowDataType::Struct(fields), arrays, validity) } -pub fn read_column(mut reader: R, column: &str) -> PolarsResult { +pub fn read_column(mut reader: R, column: &str) -> PolarsResult> { let metadata = p_read::read_metadata(&mut reader)?; let schema = p_read::infer_schema(&metadata)?; - let row_group = &metadata.row_groups[0]; - - // verify that we can read indexes - if p_read::indexes::has_indexes(row_group) { - let _indexes = p_read::indexes::read_filtered_pages( - &mut reader, - row_group, - &schema.fields, - |_, _| vec![], - )?; - } - let schema = schema.filter(|_, f| f.name == column); - let field = &schema.fields[0]; - - let statistics = deserialize(field, row_group)?; - - let mut reader = p_read::FileReader::new(reader, metadata.row_groups, schema, None); + let mut reader = FileReader::new(reader, metadata.row_groups, schema, None); let array = reader.next().unwrap()?.into_arrays().pop().unwrap(); - Ok((array, statistics)) + Ok(array) } pub fn pyarrow_nested_edge(column: &str) -> Box { @@ -91,7 +75,7 @@ pub fn pyarrow_nested_edge(column: &str) -> Box { // ] let a = ListArray::::new( ArrowDataType::LargeList(Box::new(Field::new( - "item", + "item".into(), ArrowDataType::Utf8View, true, ))), @@ -100,7 +84,7 @@ pub fn pyarrow_nested_edge(column: &str) -> Box { None, ); StructArray::new( - ArrowDataType::Struct(vec![Field::new("f1", a.data_type().clone(), true)]), + ArrowDataType::Struct(vec![Field::new("f1".into(), a.dtype().clone(), true)]), vec![a.boxed()], None, ) @@ -110,8 +94,8 @@ pub fn pyarrow_nested_edge(column: &str) -> Box { let values = pyarrow_nested_edge("struct_list_nullable"); ListArray::::new( ArrowDataType::LargeList(Box::new(Field::new( - "item", - values.data_type().clone(), + "item".into(), + values.dtype().clone(), true, ))), vec![0, 1].try_into().unwrap(), @@ -321,8 +305,8 @@ pub fn pyarrow_nested_nullable(column: &str) -> Box { let array = ListArray::::new( ArrowDataType::LargeList(Box::new(Field::new( - "item", - array.data_type().clone(), + "item".into(), + array.dtype().clone(), true, ))), vec![0, 1, 2, 3, 3, 4, 4, 4, 4, 5, 6, 8, 8] @@ -360,15 +344,21 @@ pub fn pyarrow_nested_nullable(column: &str) -> Box { match column { "list_int64_required_required" => { // [[0, 1], [], [2, 0, 3], [4, 5, 6], [], [7, 8, 9], [], [10]] - let data_type = - ArrowDataType::LargeList(Box::new(Field::new("item", ArrowDataType::Int64, false))); - ListArray::::new(data_type, offsets, values, None).boxed() + let dtype = ArrowDataType::LargeList(Box::new(Field::new( + "item".into(), + ArrowDataType::Int64, + false, + ))); + ListArray::::new(dtype, offsets, values, None).boxed() }, "list_int64_optional_required" => { // [[0, 1], [], [2, 0, 3], [4, 5, 6], [], [7, 8, 9], [], [10]] - let data_type = - ArrowDataType::LargeList(Box::new(Field::new("item", ArrowDataType::Int64, true))); - ListArray::::new(data_type, offsets, values, None).boxed() + let dtype = ArrowDataType::LargeList(Box::new(Field::new( + "item".into(), + ArrowDataType::Int64, + true, + ))); + ListArray::::new(dtype, offsets, values, None).boxed() }, "list_nested_i64" => { // [[0, 1]], None, [[2, None], [3]], [[4, 5], [6]], [], [[7], None, [9]], [[], [None], None], [[10]] @@ -429,16 +419,20 @@ pub fn pyarrow_nested_nullable(column: &str) -> Box { "struct_list_nullable" => new_struct(vec![values], vec!["a".to_string()], None).boxed(), _ 
=> { let field = match column { - "list_int64" => Field::new("item", ArrowDataType::Int64, true), - "list_int64_required" => Field::new("item", ArrowDataType::Int64, false), - "list_int16" => Field::new("item", ArrowDataType::Int16, true), - "list_bool" => Field::new("item", ArrowDataType::Boolean, true), - "list_utf8" => Field::new("item", ArrowDataType::Utf8View, true), - "list_large_binary" => Field::new("item", ArrowDataType::LargeBinary, true), - "list_decimal" => Field::new("item", ArrowDataType::Decimal(9, 0), true), - "list_decimal256" => Field::new("item", ArrowDataType::Decimal256(9, 0), true), - "list_struct_nullable" => Field::new("item", values.data_type().clone(), true), - "list_struct_list_nullable" => Field::new("item", values.data_type().clone(), true), + "list_int64" => Field::new("item".into(), ArrowDataType::Int64, true), + "list_int64_required" => Field::new("item".into(), ArrowDataType::Int64, false), + "list_int16" => Field::new("item".into(), ArrowDataType::Int16, true), + "list_bool" => Field::new("item".into(), ArrowDataType::Boolean, true), + "list_utf8" => Field::new("item".into(), ArrowDataType::Utf8View, true), + "list_large_binary" => Field::new("item".into(), ArrowDataType::LargeBinary, true), + "list_decimal" => Field::new("item".into(), ArrowDataType::Decimal(9, 0), true), + "list_decimal256" => { + Field::new("item".into(), ArrowDataType::Decimal256(9, 0), true) + }, + "list_struct_nullable" => Field::new("item".into(), values.dtype().clone(), true), + "list_struct_list_nullable" => { + Field::new("item".into(), values.dtype().clone(), true) + }, other => unreachable!("{}", other), }; @@ -447,8 +441,8 @@ pub fn pyarrow_nested_nullable(column: &str) -> Box { ])); // [0, 2, 2, 5, 8, 8, 11, 11, 12] // [[a1, a2], None, [a3, a4, a5], [a6, a7, a8], [], [a9, a10, a11], None, [a12]] - let data_type = ArrowDataType::LargeList(Box::new(field)); - ListArray::::new(data_type, offsets, values, validity).boxed() + let dtype = ArrowDataType::LargeList(Box::new(field)); + ListArray::::new(dtype, offsets, values, validity).boxed() }, } } @@ -536,7 +530,7 @@ pub fn pyarrow_nullable(column: &str) -> Box { .to(ArrowDataType::Timestamp(TimeUnit::Second, None)), ), "timestamp_s_utc" => Box::new(PrimitiveArray::::from(i64_values).to( - ArrowDataType::Timestamp(TimeUnit::Second, Some("UTC".to_string())), + ArrowDataType::Timestamp(TimeUnit::Second, Some("UTC".into())), )), _ => unreachable!(), } @@ -625,11 +619,11 @@ pub fn pyarrow_nullable_statistics(column: &str) -> Statistics { null_count: UInt64Array::from([Some(3)]).boxed(), min_value: Box::new(Int64Array::from_slice([-256]).to(ArrowDataType::Timestamp( TimeUnit::Second, - Some("UTC".to_string()), + Some("UTC".into()), ))), max_value: Box::new(Int64Array::from_slice([9]).to(ArrowDataType::Timestamp( TimeUnit::Second, - Some("UTC".to_string()), + Some("UTC".into()), ))), }, _ => unreachable!(), @@ -682,8 +676,8 @@ pub fn pyarrow_nested_nullable_statistics(column: &str) -> Statistics { let new_list = |array: Box, nullable: bool| { ListArray::::new( ArrowDataType::LargeList(Box::new(Field::new( - "item", - array.data_type().clone(), + "item".into(), + array.dtype().clone(), nullable, ))), vec![0, array.len() as i64].try_into().unwrap(), @@ -929,8 +923,8 @@ pub fn pyarrow_nested_edge_statistics(column: &str) -> Statistics { let new_list = |array: Box| { ListArray::::new( ArrowDataType::LargeList(Box::new(Field::new( - "item", - array.data_type().clone(), + "item".into(), + array.dtype().clone(), true, ))), vec![0, array.len() 
as i64].try_into().unwrap(), @@ -943,7 +937,7 @@ pub fn pyarrow_nested_edge_statistics(column: &str) -> Statistics { let fields = names .into_iter() .zip(arrays.iter()) - .map(|(n, a)| Field::new(n, a.data_type().clone(), true)) + .map(|(n, a)| Field::new(n.into(), a.dtype().clone(), true)) .collect(); StructArray::new(ArrowDataType::Struct(fields), arrays, None) }; @@ -1047,8 +1041,8 @@ pub fn pyarrow_struct(column: &str) -> Box { let mask = [true, true, false, true, true, true, true, true, true, true]; let fields = vec![ - Field::new("f1", ArrowDataType::Utf8View, true), - Field::new("f2", ArrowDataType::Boolean, true), + Field::new("f1".into(), ArrowDataType::Utf8View, true), + Field::new("f2".into(), ArrowDataType::Boolean, true), ]; match column { "struct" => { @@ -1062,8 +1056,8 @@ pub fn pyarrow_struct(column: &str) -> Box { let struct_ = pyarrow_struct("struct"); Box::new(StructArray::new( ArrowDataType::Struct(vec![ - Field::new("f1", ArrowDataType::Struct(fields), true), - Field::new("f2", ArrowDataType::Boolean, true), + Field::new("f1".into(), ArrowDataType::Struct(fields), true), + Field::new("f2".into(), ArrowDataType::Boolean, true), ]), vec![struct_, boolean], None, @@ -1073,8 +1067,8 @@ pub fn pyarrow_struct(column: &str) -> Box { let struct_ = pyarrow_struct("struct"); Box::new(StructArray::new( ArrowDataType::Struct(vec![ - Field::new("f1", ArrowDataType::Struct(fields), true), - Field::new("f2", ArrowDataType::Boolean, true), + Field::new("f1".into(), ArrowDataType::Struct(fields), true), + Field::new("f2".into(), ArrowDataType::Boolean, true), ]), vec![struct_, boolean], Some(mask.into()), @@ -1263,10 +1257,9 @@ fn integration_write( }; let encodings = schema - .fields - .iter() + .iter_values() .map(|f| { - transverse(&f.data_type, |x| { + transverse(&f.dtype, |x| { if let ArrowDataType::Dictionary(..) 
= x { Encoding::RleDictionary } else { @@ -1298,11 +1291,7 @@ fn integration_read(data: &[u8], limit: Option) -> PolarsResult PolarsResult<(ArrowSchema, RecordBatchT>)> { let values = PrimitiveArray::from_slice([1i64, 3]) .to(ArrowDataType::Timestamp( TimeUnit::Millisecond, - Some("UTC".to_string()), + Some("UTC".into()), )) .boxed(); let array7 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); @@ -1354,19 +1343,19 @@ fn generic_data() -> PolarsResult<(ArrowSchema, RecordBatchT>)> { let array13 = PrimitiveArray::::from_slice([1, 2, 3]) .to(ArrowDataType::Interval(IntervalUnit::YearMonth)); - let schema = ArrowSchema::from(vec![ - Field::new("a1", array1.data_type().clone(), true), - Field::new("a2", array2.data_type().clone(), true), - Field::new("a3", array3.data_type().clone(), true), - Field::new("a4", array4.data_type().clone(), true), - Field::new("a6", array6.data_type().clone(), true), - Field::new("a7", array7.data_type().clone(), true), - Field::new("a8", array8.data_type().clone(), true), - Field::new("a9", array9.data_type().clone(), true), - Field::new("a10", array10.data_type().clone(), true), - Field::new("a11", array11.data_type().clone(), true), - Field::new("a12", array12.data_type().clone(), true), - Field::new("a13", array13.data_type().clone(), true), + let schema = ArrowSchema::from_iter([ + Field::new("a1".into(), array1.dtype().clone(), true), + Field::new("a2".into(), array2.dtype().clone(), true), + Field::new("a3".into(), array3.dtype().clone(), true), + Field::new("a4".into(), array4.dtype().clone(), true), + Field::new("a6".into(), array6.dtype().clone(), true), + Field::new("a7".into(), array7.dtype().clone(), true), + Field::new("a8".into(), array8.dtype().clone(), true), + Field::new("a9".into(), array9.dtype().clone(), true), + Field::new("a10".into(), array10.dtype().clone(), true), + Field::new("a11".into(), array11.dtype().clone(), true), + Field::new("a12".into(), array12.dtype().clone(), true), + Field::new("a13".into(), array13.dtype().clone(), true), ]); let chunk = RecordBatchT::try_new(vec![ array1.boxed(), @@ -1448,7 +1437,7 @@ fn data>( ]; let mut array = MutableListArray::::new_with_field( MutablePrimitiveArray::::new(), - "item", + "item".into(), inner_is_nullable, ); array.try_extend(data).unwrap(); @@ -1460,11 +1449,8 @@ fn assert_array_roundtrip( array: Box, limit: Option, ) -> PolarsResult<()> { - let schema = ArrowSchema::from(vec![Field::new( - "a1", - array.data_type().clone(), - is_nullable, - )]); + let schema = + ArrowSchema::from_iter([Field::new("a1".into(), array.dtype().clone(), is_nullable)]); let chunk = RecordBatchT::try_new(vec![array])?; assert_roundtrip(schema, chunk, limit) @@ -1516,7 +1502,7 @@ fn list_slice() -> PolarsResult<()> { ]; let mut array = MutableListArray::::new_with_field( MutablePrimitiveArray::::new(), - "item", + "item".into(), true, ); array.try_extend(data).unwrap(); @@ -1553,7 +1539,7 @@ fn list_int_nullable() -> PolarsResult<()> { ]; let mut array = MutableListArray::::new_with_field( MutablePrimitiveArray::::new(), - "item", + "item".into(), true, ); array.try_extend(data).unwrap(); @@ -1572,9 +1558,9 @@ fn limit_list() -> PolarsResult<()> { } fn nested_dict_data( - data_type: ArrowDataType, + dtype: ArrowDataType, ) -> PolarsResult<(ArrowSchema, RecordBatchT>)> { - let values = match data_type { + let values = match dtype { ArrowDataType::Float32 => PrimitiveArray::from_slice([1.0f32, 3.0]).boxed(), ArrowDataType::Utf8View => Utf8ViewArray::from_slice([Some("a"), Some("b")]).boxed(), 
_ => unreachable!(), @@ -1584,8 +1570,8 @@ fn nested_dict_data( let values = DictionaryArray::try_from_keys(indices, values).unwrap(); let values = LargeListArray::try_new( ArrowDataType::LargeList(Box::new(Field::new( - "item", - values.data_type().clone(), + "item".into(), + values.dtype().clone(), false, ))), vec![0i64, 0, 0, 2, 3].try_into().unwrap(), @@ -1593,7 +1579,7 @@ fn nested_dict_data( Some([true, false, true, true].into()), )?; - let schema = ArrowSchema::from(vec![Field::new("c1", values.data_type().clone(), true)]); + let schema = ArrowSchema::from_iter([Field::new("c1".into(), values.dtype().clone(), true)]); let chunk = RecordBatchT::try_new(vec![values.boxed()])?; Ok((schema, chunk)) @@ -1624,7 +1610,7 @@ fn nested_dict_limit() -> PolarsResult<()> { fn filter_chunk() -> PolarsResult<()> { let chunk1 = RecordBatchT::new(vec![PrimitiveArray::from_slice([1i16, 3]).boxed()]); let chunk2 = RecordBatchT::new(vec![PrimitiveArray::from_slice([2i16, 4]).boxed()]); - let schema = ArrowSchema::from(vec![Field::new("c1", ArrowDataType::Int16, true)]); + let schema = ArrowSchema::from_iter([Field::new("c1".into(), ArrowDataType::Int16, true)]); let r = integration_write(&schema, &[chunk1.clone(), chunk2.clone()])?; @@ -1644,7 +1630,7 @@ fn filter_chunk() -> PolarsResult<()> { .map(|(_, row_group)| row_group) .collect(); - let reader = p_read::FileReader::new(reader, row_groups, schema, None); + let reader = FileReader::new(reader, row_groups, schema, None); let new_chunks = reader.collect::>>()?; diff --git a/crates/polars/tests/it/io/parquet/arrow/read.rs b/crates/polars/tests/it/io/parquet/arrow/read.rs index 6aaeb8c297cb..ac03b7fb1e10 100644 --- a/crates/polars/tests/it/io/parquet/arrow/read.rs +++ b/crates/polars/tests/it/io/parquet/arrow/read.rs @@ -3,9 +3,12 @@ use std::path::PathBuf; use polars_parquet::arrow::read::*; use super::*; +use crate::io::parquet::read::file::FileReader; #[cfg(feature = "parquet")] #[test] fn all_types() -> PolarsResult<()> { + use crate::io::parquet::read::file::FileReader; + let dir = env!("CARGO_MANIFEST_DIR"); let path = PathBuf::from(dir).join("../../docs/data/alltypes_plain.parquet"); @@ -49,6 +52,8 @@ fn all_types() -> PolarsResult<()> { #[test] fn all_types_chunked() -> PolarsResult<()> { // this has one batch with 8 elements + + use crate::io::parquet::read::file::FileReader; let dir = env!("CARGO_MANIFEST_DIR"); let path = PathBuf::from(dir).join("../../docs/data/alltypes_plain.parquet"); let mut reader = std::fs::File::open(path)?; @@ -92,8 +97,6 @@ fn all_types_chunked() -> PolarsResult<()> { #[test] fn read_int96_timestamps() -> PolarsResult<()> { - use std::collections::BTreeMap; - let timestamp_data = &[ 0x50, 0x41, 0x52, 0x31, 0x15, 0x04, 0x15, 0x48, 0x15, 0x3c, 0x4c, 0x15, 0x06, 0x15, 0x00, 0x12, 0x00, 0x00, 0x24, 0x00, 0x00, 0x0d, 0x01, 0x08, 0x9f, 0xd5, 0x1f, 0x0d, 0x0a, 0x44, @@ -120,14 +123,11 @@ fn read_int96_timestamps() -> PolarsResult<()> { let parse = |time_unit: TimeUnit| { let mut reader = Cursor::new(timestamp_data); let metadata = read_metadata(&mut reader)?; - let schema = arrow::datatypes::ArrowSchema { - fields: vec![arrow::datatypes::Field::new( - "timestamps", - arrow::datatypes::ArrowDataType::Timestamp(time_unit, None), - false, - )], - metadata: BTreeMap::new(), - }; + let schema = arrow::datatypes::ArrowSchema::from_iter([arrow::datatypes::Field::new( + "timestamps".into(), + arrow::datatypes::ArrowDataType::Timestamp(time_unit, None), + false, + )]); let reader = FileReader::new(reader, metadata.row_groups, 
schema, None); reader.collect::>>() }; diff --git a/crates/polars/tests/it/io/parquet/arrow/write.rs b/crates/polars/tests/it/io/parquet/arrow/write.rs index 9c25f346c2e1..8863c068baff 100644 --- a/crates/polars/tests/it/io/parquet/arrow/write.rs +++ b/crates/polars/tests/it/io/parquet/arrow/write.rs @@ -9,7 +9,7 @@ fn round_trip( compression: CompressionOptions, encodings: Vec, ) -> PolarsResult<()> { - round_trip_opt_stats(column, file, version, compression, encodings, true) + round_trip_opt_stats(column, file, version, compression, encodings) } fn round_trip_opt_stats( @@ -18,9 +18,8 @@ fn round_trip_opt_stats( version: Version, compression: CompressionOptions, encodings: Vec, - check_stats: bool, ) -> PolarsResult<()> { - let (array, statistics) = match file { + let (array, _statistics) = match file { "nested" => ( pyarrow_nested_nullable(column), pyarrow_nested_nullable_statistics(column), @@ -41,8 +40,8 @@ fn round_trip_opt_stats( _ => unreachable!(), }; - let field = Field::new("a1", array.data_type().clone(), true); - let schema = ArrowSchema::from(vec![field]); + let field = Field::new("a1".into(), array.dtype().clone(), true); + let schema = ArrowSchema::from_iter([field]); let options = WriteOptions { statistics: StatisticsOptions::full(), @@ -68,12 +67,9 @@ fn round_trip_opt_stats( std::fs::write("list_struct_list_nullable.parquet", &data).unwrap(); - let (result, stats) = read_column(&mut Cursor::new(data), "a1")?; + let result = read_column(&mut Cursor::new(data), "a1")?; assert_eq!(array.as_ref(), result.as_ref()); - if check_stats { - assert_eq!(statistics, stats); - } Ok(()) } @@ -364,7 +360,6 @@ fn list_nested_inner_required_required_i64() -> PolarsResult<()> { Version::V1, CompressionOptions::Uncompressed, vec![Encoding::Plain], - false, ) } @@ -376,7 +371,6 @@ fn v1_nested_struct_list_nullable() -> PolarsResult<()> { Version::V1, CompressionOptions::Uncompressed, vec![Encoding::Plain], - true, ) } @@ -388,7 +382,6 @@ fn v1_nested_list_struct_list_nullable() -> PolarsResult<()> { Version::V1, CompressionOptions::Uncompressed, vec![Encoding::Plain], - true, ) } diff --git a/crates/polars/tests/it/io/parquet/mod.rs b/crates/polars/tests/it/io/parquet/mod.rs index 5cc89e7452d7..5d088aab3b15 100644 --- a/crates/polars/tests/it/io/parquet/mod.rs +++ b/crates/polars/tests/it/io/parquet/mod.rs @@ -1,6 +1,6 @@ #![forbid(unsafe_code)] mod arrow; -mod read; +pub(crate) mod read; mod roundtrip; mod write; @@ -113,7 +113,7 @@ pub fn alltypes_plain(column: &str) -> Array { pub fn alltypes_statistics(column: &str) -> Statistics { match column { "id" => PrimitiveStatistics:: { - primitive_type: PrimitiveType::from_physical("col".to_string(), PhysicalType::Int32), + primitive_type: PrimitiveType::from_physical("col".into(), PhysicalType::Int32), null_count: Some(0), distinct_count: None, min_value: Some(0), @@ -121,7 +121,7 @@ pub fn alltypes_statistics(column: &str) -> Statistics { } .into(), "id-short-array" => PrimitiveStatistics:: { - primitive_type: PrimitiveType::from_physical("col".to_string(), PhysicalType::Int32), + primitive_type: PrimitiveType::from_physical("col".into(), PhysicalType::Int32), null_count: Some(0), distinct_count: None, min_value: Some(4), @@ -136,7 +136,7 @@ pub fn alltypes_statistics(column: &str) -> Statistics { } .into(), "tinyint_col" | "smallint_col" | "int_col" => PrimitiveStatistics:: { - primitive_type: PrimitiveType::from_physical("col".to_string(), PhysicalType::Int32), + primitive_type: PrimitiveType::from_physical("col".into(), 
PhysicalType::Int32), null_count: Some(0), distinct_count: None, min_value: Some(0), @@ -144,7 +144,7 @@ pub fn alltypes_statistics(column: &str) -> Statistics { } .into(), "bigint_col" => PrimitiveStatistics:: { - primitive_type: PrimitiveType::from_physical("col".to_string(), PhysicalType::Int64), + primitive_type: PrimitiveType::from_physical("col".into(), PhysicalType::Int64), null_count: Some(0), distinct_count: None, min_value: Some(0), @@ -152,7 +152,7 @@ pub fn alltypes_statistics(column: &str) -> Statistics { } .into(), "float_col" => PrimitiveStatistics:: { - primitive_type: PrimitiveType::from_physical("col".to_string(), PhysicalType::Float), + primitive_type: PrimitiveType::from_physical("col".into(), PhysicalType::Float), null_count: Some(0), distinct_count: None, min_value: Some(0.0), @@ -160,7 +160,7 @@ pub fn alltypes_statistics(column: &str) -> Statistics { } .into(), "double_col" => PrimitiveStatistics:: { - primitive_type: PrimitiveType::from_physical("col".to_string(), PhysicalType::Double), + primitive_type: PrimitiveType::from_physical("col".into(), PhysicalType::Double), null_count: Some(0), distinct_count: None, min_value: Some(0.0), @@ -168,10 +168,7 @@ pub fn alltypes_statistics(column: &str) -> Statistics { } .into(), "date_string_col" => BinaryStatistics { - primitive_type: PrimitiveType::from_physical( - "col".to_string(), - PhysicalType::ByteArray, - ), + primitive_type: PrimitiveType::from_physical("col".into(), PhysicalType::ByteArray), null_count: Some(0), distinct_count: None, min_value: Some(vec![48, 49, 47, 48, 49, 47, 48, 57]), @@ -179,10 +176,7 @@ pub fn alltypes_statistics(column: &str) -> Statistics { } .into(), "string_col" => BinaryStatistics { - primitive_type: PrimitiveType::from_physical( - "col".to_string(), - PhysicalType::ByteArray, - ), + primitive_type: PrimitiveType::from_physical("col".into(), PhysicalType::ByteArray), null_count: Some(0), distinct_count: None, min_value: Some(vec![48]), diff --git a/crates/polars-parquet/src/arrow/read/file.rs b/crates/polars/tests/it/io/parquet/read/file.rs similarity index 93% rename from crates/polars-parquet/src/arrow/read/file.rs rename to crates/polars/tests/it/io/parquet/read/file.rs index a390022331be..5007dcdf0755 100644 --- a/crates/polars-parquet/src/arrow/read/file.rs +++ b/crates/polars/tests/it/io/parquet/read/file.rs @@ -4,10 +4,9 @@ use arrow::array::Array; use arrow::datatypes::ArrowSchema; use arrow::record_batch::RecordBatchT; use polars_error::PolarsResult; +use polars_parquet::read::{Filter, RowGroupMetaData}; -use super::deserialize::Filter; -use super::{RowGroupDeserializer, RowGroupMetaData}; -use crate::arrow::read::read_columns_many; +use super::row_group::{read_columns_many, RowGroupDeserializer}; /// An iterator of [`RecordBatchT`]s coming from row groups of a parquet file. /// @@ -53,11 +52,6 @@ impl FileReader { } Ok(result) } - - /// Returns the [`ArrowSchema`] associated to this file. 
- pub fn schema(&self) -> &ArrowSchema { - &self.row_groups.schema - } } impl Iterator for FileReader { @@ -132,7 +126,7 @@ impl RowGroupReader { #[inline] fn _next(&mut self) -> PolarsResult> { - if self.schema.fields.is_empty() { + if self.schema.is_empty() { return Ok(None); } if self.remaining_rows == 0 { @@ -151,7 +145,7 @@ impl RowGroupReader { let column_chunks = read_columns_many( &mut self.reader, &row_group, - self.schema.fields.clone(), + &self.schema, Some(Filter::new_limited(self.remaining_rows)), )?; diff --git a/crates/polars/tests/it/io/parquet/read/indexes.rs b/crates/polars/tests/it/io/parquet/read/indexes.rs deleted file mode 100644 index e55c8b37a474..000000000000 --- a/crates/polars/tests/it/io/parquet/read/indexes.rs +++ /dev/null @@ -1,143 +0,0 @@ -use polars_parquet::parquet::error::ParquetError; -use polars_parquet::parquet::indexes::{ - BooleanIndex, BoundaryOrder, ByteIndex, Index, NativeIndex, PageIndex, PageLocation, -}; -use polars_parquet::parquet::read::{read_columns_indexes, read_metadata, read_pages_locations}; -use polars_parquet::parquet::schema::types::{ - FieldInfo, PhysicalType, PrimitiveConvertedType, PrimitiveLogicalType, PrimitiveType, -}; -use polars_parquet::parquet::schema::Repetition; - -/* -import pyspark.sql # 3.2.1 -spark = pyspark.sql.SparkSession.builder.getOrCreate() -spark.conf.set("parquet.bloom.filter.enabled", True) -spark.conf.set("parquet.bloom.filter.expected.ndv", 10) -spark.conf.set("parquet.bloom.filter.max.bytes", 32) - -data = [(i, f"{i}", False) for i in range(10)] -df = spark.createDataFrame(data, ["id", "string", "bool"]).repartition(1) - -df.write.parquet("bla.parquet", mode = "overwrite") -*/ -const FILE: &[u8] = &[ - 80, 65, 82, 49, 21, 0, 21, 172, 1, 21, 138, 1, 21, 169, 161, 209, 137, 5, 28, 21, 20, 21, 0, - 21, 6, 21, 8, 0, 0, 86, 24, 2, 0, 0, 0, 20, 1, 0, 13, 1, 17, 9, 1, 22, 1, 1, 0, 3, 1, 5, 12, 0, - 0, 0, 4, 1, 5, 12, 0, 0, 0, 5, 1, 5, 12, 0, 0, 0, 6, 1, 5, 12, 0, 0, 0, 7, 1, 5, 72, 0, 0, 0, - 8, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 21, 0, 21, 112, 21, 104, 21, 138, 239, 232, - 170, 15, 28, 21, 20, 21, 0, 21, 6, 21, 8, 0, 0, 56, 40, 2, 0, 0, 0, 20, 1, 1, 0, 0, 0, 48, 1, - 5, 0, 49, 1, 5, 0, 50, 1, 5, 0, 51, 1, 5, 0, 52, 1, 5, 0, 53, 1, 5, 60, 54, 1, 0, 0, 0, 55, 1, - 0, 0, 0, 56, 1, 0, 0, 0, 57, 21, 0, 21, 16, 21, 20, 21, 202, 209, 169, 227, 4, 28, 21, 20, 21, - 0, 21, 6, 21, 8, 0, 0, 8, 28, 2, 0, 0, 0, 20, 1, 0, 0, 25, 17, 2, 25, 24, 8, 0, 0, 0, 0, 0, 0, - 0, 0, 25, 24, 8, 9, 0, 0, 0, 0, 0, 0, 0, 21, 2, 25, 22, 0, 0, 25, 17, 2, 25, 24, 1, 48, 25, 24, - 1, 57, 21, 2, 25, 22, 0, 0, 25, 17, 2, 25, 24, 1, 0, 25, 24, 1, 0, 21, 2, 25, 22, 0, 0, 25, 28, - 22, 8, 21, 188, 1, 22, 0, 0, 0, 25, 28, 22, 196, 1, 21, 150, 1, 22, 0, 0, 0, 25, 28, 22, 218, - 2, 21, 66, 22, 0, 0, 0, 21, 64, 28, 28, 0, 0, 28, 28, 0, 0, 28, 28, 0, 0, 0, 24, 130, 24, 8, - 134, 8, 68, 6, 2, 101, 128, 10, 64, 2, 38, 78, 114, 1, 64, 38, 1, 192, 194, 152, 64, 70, 0, 36, - 56, 121, 64, 0, 21, 64, 28, 28, 0, 0, 28, 28, 0, 0, 28, 28, 0, 0, 0, 8, 17, 10, 29, 5, 88, 194, - 0, 35, 208, 25, 16, 70, 68, 48, 38, 17, 16, 140, 68, 98, 56, 0, 131, 4, 193, 40, 129, 161, 160, - 1, 96, 21, 64, 28, 28, 0, 0, 28, 28, 0, 0, 28, 28, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 21, 2, 25, 76, 72, 12, 115, 112, - 97, 114, 107, 95, 115, 99, 104, 101, 109, 97, 21, 6, 0, 21, 4, 37, 2, 24, 2, 105, 100, 0, 21, - 12, 37, 2, 24, 6, 115, 116, 114, 105, 110, 103, 37, 0, 76, 28, 0, 0, 0, 21, 0, 37, 2, 24, 4, - 98, 
111, 111, 108, 0, 22, 20, 25, 28, 25, 60, 38, 8, 28, 21, 4, 25, 53, 0, 6, 8, 25, 24, 2, - 105, 100, 21, 2, 22, 20, 22, 222, 1, 22, 188, 1, 38, 8, 60, 24, 8, 9, 0, 0, 0, 0, 0, 0, 0, 24, - 8, 0, 0, 0, 0, 0, 0, 0, 0, 22, 0, 40, 8, 9, 0, 0, 0, 0, 0, 0, 0, 24, 8, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 25, 28, 21, 0, 21, 0, 21, 2, 0, 22, 226, 4, 0, 22, 158, 4, 21, 22, 22, 156, 3, 21, 62, 0, - 38, 196, 1, 28, 21, 12, 25, 53, 0, 6, 8, 25, 24, 6, 115, 116, 114, 105, 110, 103, 21, 2, 22, - 20, 22, 158, 1, 22, 150, 1, 38, 196, 1, 60, 54, 0, 40, 1, 57, 24, 1, 48, 0, 25, 28, 21, 0, 21, - 0, 21, 2, 0, 22, 192, 5, 0, 22, 180, 4, 21, 24, 22, 218, 3, 21, 34, 0, 38, 218, 2, 28, 21, 0, - 25, 53, 0, 6, 8, 25, 24, 4, 98, 111, 111, 108, 21, 2, 22, 20, 22, 62, 22, 66, 38, 218, 2, 60, - 24, 1, 0, 24, 1, 0, 22, 0, 40, 1, 0, 24, 1, 0, 0, 25, 28, 21, 0, 21, 0, 21, 2, 0, 22, 158, 6, - 0, 22, 204, 4, 21, 22, 22, 252, 3, 21, 34, 0, 22, 186, 3, 22, 20, 38, 8, 22, 148, 3, 20, 0, 0, - 25, 44, 24, 24, 111, 114, 103, 46, 97, 112, 97, 99, 104, 101, 46, 115, 112, 97, 114, 107, 46, - 118, 101, 114, 115, 105, 111, 110, 24, 5, 51, 46, 50, 46, 49, 0, 24, 41, 111, 114, 103, 46, 97, - 112, 97, 99, 104, 101, 46, 115, 112, 97, 114, 107, 46, 115, 113, 108, 46, 112, 97, 114, 113, - 117, 101, 116, 46, 114, 111, 119, 46, 109, 101, 116, 97, 100, 97, 116, 97, 24, 213, 1, 123, 34, - 116, 121, 112, 101, 34, 58, 34, 115, 116, 114, 117, 99, 116, 34, 44, 34, 102, 105, 101, 108, - 100, 115, 34, 58, 91, 123, 34, 110, 97, 109, 101, 34, 58, 34, 105, 100, 34, 44, 34, 116, 121, - 112, 101, 34, 58, 34, 108, 111, 110, 103, 34, 44, 34, 110, 117, 108, 108, 97, 98, 108, 101, 34, - 58, 116, 114, 117, 101, 44, 34, 109, 101, 116, 97, 100, 97, 116, 97, 34, 58, 123, 125, 125, 44, - 123, 34, 110, 97, 109, 101, 34, 58, 34, 115, 116, 114, 105, 110, 103, 34, 44, 34, 116, 121, - 112, 101, 34, 58, 34, 115, 116, 114, 105, 110, 103, 34, 44, 34, 110, 117, 108, 108, 97, 98, - 108, 101, 34, 58, 116, 114, 117, 101, 44, 34, 109, 101, 116, 97, 100, 97, 116, 97, 34, 58, 123, - 125, 125, 44, 123, 34, 110, 97, 109, 101, 34, 58, 34, 98, 111, 111, 108, 34, 44, 34, 116, 121, - 112, 101, 34, 58, 34, 98, 111, 111, 108, 101, 97, 110, 34, 44, 34, 110, 117, 108, 108, 97, 98, - 108, 101, 34, 58, 116, 114, 117, 101, 44, 34, 109, 101, 116, 97, 100, 97, 116, 97, 34, 58, 123, - 125, 125, 93, 125, 0, 24, 74, 112, 97, 114, 113, 117, 101, 116, 45, 109, 114, 32, 118, 101, - 114, 115, 105, 111, 110, 32, 49, 46, 49, 50, 46, 50, 32, 40, 98, 117, 105, 108, 100, 32, 55, - 55, 101, 51, 48, 99, 56, 48, 57, 51, 51, 56, 54, 101, 99, 53, 50, 99, 51, 99, 102, 97, 54, 99, - 51, 52, 98, 55, 101, 102, 51, 51, 50, 49, 51, 50, 50, 99, 57, 52, 41, 25, 60, 28, 0, 0, 28, 0, - 0, 28, 0, 0, 0, 182, 2, 0, 0, 80, 65, 82, 49, -]; - -#[test] -fn test() -> Result<(), ParquetError> { - let mut reader = std::io::Cursor::new(FILE); - - let expected_index = vec![ - Box::new(NativeIndex:: { - primitive_type: PrimitiveType::from_physical("id".to_string(), PhysicalType::Int64), - indexes: vec![PageIndex { - min: Some(0), - max: Some(9), - null_count: Some(0), - }], - boundary_order: BoundaryOrder::Ascending, - }) as Box, - Box::new(ByteIndex { - primitive_type: PrimitiveType { - field_info: FieldInfo { - name: "string".to_string(), - repetition: Repetition::Optional, - id: None, - }, - logical_type: Some(PrimitiveLogicalType::String), - converted_type: Some(PrimitiveConvertedType::Utf8), - physical_type: PhysicalType::ByteArray, - }, - indexes: vec![PageIndex { - min: Some(b"0".to_vec()), - max: Some(b"9".to_vec()), - null_count: 
Some(0), - }], - boundary_order: BoundaryOrder::Ascending, - }), - Box::new(BooleanIndex { - indexes: vec![PageIndex { - min: Some(false), - max: Some(false), - null_count: Some(0), - }], - boundary_order: BoundaryOrder::Ascending, - }), - ]; - let expected_page_locations = vec![ - vec![PageLocation { - offset: 4, - compressed_page_size: 94, - first_row_index: 0, - }], - vec![PageLocation { - offset: 98, - compressed_page_size: 75, - first_row_index: 0, - }], - vec![PageLocation { - offset: 173, - compressed_page_size: 33, - first_row_index: 0, - }], - ]; - - let metadata = read_metadata(&mut reader)?; - let columns = &metadata.row_groups[0].columns(); - - let indexes = read_columns_indexes(&mut reader, columns)?; - assert_eq!(&indexes, &expected_index); - - let pages = read_pages_locations(&mut reader, columns)?; - assert_eq!(pages, expected_page_locations); - - Ok(()) -} diff --git a/crates/polars/tests/it/io/parquet/read/mod.rs b/crates/polars/tests/it/io/parquet/read/mod.rs index 99b1c1b7c9dd..c4ba7d5e418e 100644 --- a/crates/polars/tests/it/io/parquet/read/mod.rs +++ b/crates/polars/tests/it/io/parquet/read/mod.rs @@ -4,10 +4,11 @@ mod binary; /// but OTOH it has no external dependencies and is very familiar to Rust developers. mod boolean; mod dictionary; +pub(crate) mod file; mod fixed_binary; -mod indexes; mod primitive; mod primitive_nested; +pub(crate) mod row_group; mod struct_; mod utils; @@ -16,11 +17,9 @@ use std::fs::File; use dictionary::DecodedDictPage; use polars_parquet::parquet::encoding::hybrid_rle::HybridRleDecoder; use polars_parquet::parquet::error::{ParquetError, ParquetResult}; -use polars_parquet::parquet::metadata::ColumnChunkMetaData; +use polars_parquet::parquet::metadata::ColumnChunkMetadata; use polars_parquet::parquet::page::DataPage; -use polars_parquet::parquet::read::{ - get_column_iterator, get_field_columns, read_metadata, BasicDecompressor, -}; +use polars_parquet::parquet::read::{get_column_iterator, read_metadata, BasicDecompressor}; use polars_parquet::parquet::schema::types::{GroupConvertedType, ParquetType}; use polars_parquet::parquet::schema::Repetition; use polars_parquet::parquet::types::int96_to_i64_ns; @@ -142,9 +141,9 @@ pub fn page_to_array(page: &DataPage, dict: Option<&DecodedDictPage>) -> Parquet /// Reads columns into an [`Array`]. /// This is CPU-intensive: decompress, decode and de-serialize. -pub fn columns_to_array(mut columns: I, field: &ParquetType) -> ParquetResult +pub fn columns_to_array<'a, I>(mut columns: I, field: &ParquetType) -> ParquetResult where - I: Iterator>, + I: Iterator>, { let mut validity = vec![]; let mut has_filled = false; @@ -157,6 +156,7 @@ where .map(|dict| dictionary::deserialize(&dict, column.physical_type())) .transpose()?; while let Some(page) = iterator.next().transpose()? 
{ + let page = page.decompress(&mut iterator)?; if !has_filled { struct_::extend_validity(&mut validity, &page)?; } @@ -200,11 +200,11 @@ pub fn read_column( reader, &metadata.row_groups[row_group], field.name(), - None, usize::MAX, ); - let mut statistics = get_field_columns(metadata.row_groups[row_group].columns(), field.name()) + let mut statistics = metadata.row_groups[row_group] + .columns_under_root_iter(field.name()) .map(|column_meta| column_meta.statistics().transpose()) .collect::>>()?; diff --git a/crates/polars/tests/it/io/parquet/read/primitive.rs b/crates/polars/tests/it/io/parquet/read/primitive.rs index d9665f353c53..960c502fb82d 100644 --- a/crates/polars/tests/it/io/parquet/read/primitive.rs +++ b/crates/polars/tests/it/io/parquet/read/primitive.rs @@ -26,7 +26,6 @@ impl<'a, T: NativeType> PageState<'a, T> { page: &'a DataPage, dict: Option<&'a PrimitivePageDict>, ) -> Result { - assert!(page.selected_rows().is_none()); NativePageState::try_new(page, dict).map(Self::Nominal) } } diff --git a/crates/polars/tests/it/io/parquet/read/primitive_nested.rs b/crates/polars/tests/it/io/parquet/read/primitive_nested.rs index e4abd2046432..36fdb254420a 100644 --- a/crates/polars/tests/it/io/parquet/read/primitive_nested.rs +++ b/crates/polars/tests/it/io/parquet/read/primitive_nested.rs @@ -1,6 +1,7 @@ +use polars_parquet::parquet::encoding::bitpacked::{Unpackable, Unpacked}; use polars_parquet::parquet::encoding::hybrid_rle::HybridRleDecoder; use polars_parquet::parquet::encoding::{bitpacked, uleb128, Encoding}; -use polars_parquet::parquet::error::ParquetError; +use polars_parquet::parquet::error::{ParquetError, ParquetResult}; use polars_parquet::parquet::page::{split_buffer, DataPage, EncodedSplitBuffer}; use polars_parquet::parquet::read::levels::get_bit_width; use polars_parquet::parquet::types::NativeType; @@ -171,6 +172,51 @@ pub fn page_to_array( } } +pub struct DecoderIter<'a, T: Unpackable> { + pub(crate) decoder: bitpacked::Decoder<'a, T>, + pub(crate) buffered: T::Unpacked, + pub(crate) unpacked_start: usize, + pub(crate) unpacked_end: usize, +} + +impl<'a, T: Unpackable> Iterator for DecoderIter<'a, T> { + type Item = T; + + fn next(&mut self) -> Option { + if self.unpacked_start >= self.unpacked_end { + let length; + (self.buffered, length) = self.decoder.chunked().next_inexact()?; + debug_assert!(length > 0); + self.unpacked_start = 1; + self.unpacked_end = length; + return Some(self.buffered[0]); + } + + let v = self.buffered[self.unpacked_start]; + self.unpacked_start += 1; + Some(v) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.decoder.len() + self.unpacked_end - self.unpacked_start; + (len, Some(len)) + } +} + +impl<'a, T: Unpackable> ExactSizeIterator for DecoderIter<'a, T> {} + +impl<'a, T: Unpackable> DecoderIter<'a, T> { + pub fn new(packed: &'a [u8], num_bits: usize, length: usize) -> ParquetResult { + assert!(num_bits > 0); + Ok(Self { + decoder: bitpacked::Decoder::try_new(packed, num_bits, length)?, + buffered: T::Unpacked::zero(), + unpacked_start: 0, + unpacked_end: 0, + }) + } +} + fn read_dict_array( rep_levels: &[u8], def_levels: &[u8], @@ -188,8 +234,7 @@ fn read_dict_array( let (_, consumed) = uleb128::decode(values); let values = &values[consumed..]; - let indices = bitpacked::Decoder::::try_new(values, bit_width as usize, length as usize)? 
- .collect_into_iter(); + let indices = DecoderIter::::new(values, bit_width as usize, length as usize)?; let values = indices.map(|id| dict_values[id as usize]); diff --git a/crates/polars-parquet/src/arrow/read/row_group.rs b/crates/polars/tests/it/io/parquet/read/row_group.rs similarity index 63% rename from crates/polars-parquet/src/arrow/read/row_group.rs rename to crates/polars/tests/it/io/parquet/read/row_group.rs index 0156569b4cd9..f23ee779b120 100644 --- a/crates/polars-parquet/src/arrow/read/row_group.rs +++ b/crates/polars/tests/it/io/parquet/read/row_group.rs @@ -3,15 +3,14 @@ use std::io::{Read, Seek}; use arrow::array::Array; use arrow::datatypes::Field; use arrow::record_batch::RecordBatchT; +use polars::prelude::ArrowSchema; use polars_error::PolarsResult; +use polars_parquet::arrow::read::{column_iter_to_arrays, Filter}; +use polars_parquet::parquet::metadata::ColumnChunkMetadata; +use polars_parquet::parquet::read::{BasicDecompressor, PageReader}; +use polars_parquet::read::RowGroupMetaData; use polars_utils::mmap::MemReader; -use super::{ArrayIter, RowGroupMetaData}; -use crate::arrow::read::column_iter_to_arrays; -use crate::arrow::read::deserialize::Filter; -use crate::parquet::metadata::ColumnChunkMetaData; -use crate::parquet::read::{BasicDecompressor, PageReader}; - /// An [`Iterator`] of [`RecordBatchT`] that (dynamically) adapts a vector of iterators of [`Array`] into /// an iterator of [`RecordBatchT`]. /// @@ -23,7 +22,7 @@ use crate::parquet::read::{BasicDecompressor, PageReader}; pub struct RowGroupDeserializer { num_rows: usize, remaining_rows: usize, - column_chunks: Vec>, + column_chunks: Vec>, } impl RowGroupDeserializer { @@ -32,11 +31,7 @@ impl RowGroupDeserializer { /// # Panic /// This function panics iff any of the `column_chunks` /// do not return an array with an equal length. - pub fn new( - column_chunks: Vec>, - num_rows: usize, - limit: Option, - ) -> Self { + pub fn new(column_chunks: Vec>, num_rows: usize, limit: Option) -> Self { Self { num_rows, remaining_rows: limit.unwrap_or(usize::MAX).min(num_rows), @@ -57,12 +52,7 @@ impl Iterator for RowGroupDeserializer { if self.remaining_rows == 0 { return None; } - let chunk = self - .column_chunks - .iter_mut() - .map(|iter| iter.next().unwrap()) - .collect::>>() - .and_then(RecordBatchT::try_new); + let chunk = RecordBatchT::try_new(std::mem::take(&mut self.column_chunks)); self.remaining_rows = self.remaining_rows.saturating_sub( chunk .as_ref() @@ -74,57 +64,31 @@ impl Iterator for RowGroupDeserializer { } } -/// Returns all [`ColumnChunkMetaData`] associated to `field_name`. -/// For non-nested parquet types, this returns a single column -pub fn get_field_columns<'a>( - columns: &'a [ColumnChunkMetaData], - field_name: &str, -) -> Vec<&'a ColumnChunkMetaData> { - columns - .iter() - .filter(|x| x.descriptor().path_in_schema[0] == field_name) - .collect() -} - -/// Returns all [`ColumnChunkMetaData`] associated to `field_name`. 
-/// For non-nested parquet types, this returns a single column -pub fn get_field_pages<'a, T>( - columns: &'a [ColumnChunkMetaData], - items: &'a [T], - field_name: &str, -) -> Vec<&'a T> { - columns - .iter() - .zip(items) - .filter(|(metadata, _)| metadata.descriptor().path_in_schema[0] == field_name) - .map(|(_, item)| item) - .collect() -} - /// Reads all columns that are part of the parquet field `field_name` /// # Implementation /// This operation is IO-bounded `O(C)` where C is the number of columns associated to /// the field (one for non-nested types) pub fn read_columns<'a, R: Read + Seek>( reader: &mut R, - columns: &'a [ColumnChunkMetaData], - field_name: &str, -) -> PolarsResult)>> { - get_field_columns(columns, field_name) - .into_iter() + row_group_metadata: &'a RowGroupMetaData, + field_name: &'a str, +) -> PolarsResult)>> { + row_group_metadata + .columns_under_root_iter(field_name) .map(|meta| _read_single_column(reader, meta)) .collect() } fn _read_single_column<'a, R>( reader: &mut R, - meta: &'a ColumnChunkMetaData, -) -> PolarsResult<(&'a ColumnChunkMetaData, Vec)> + meta: &'a ColumnChunkMetadata, +) -> PolarsResult<(&'a ColumnChunkMetadata, Vec)> where R: Read + Seek, { - let (start, length) = meta.byte_range(); - reader.seek(std::io::SeekFrom::Start(start))?; + let byte_range = meta.byte_range(); + let length = byte_range.end - byte_range.start; + reader.seek(std::io::SeekFrom::Start(byte_range.start))?; let mut chunk = vec![]; chunk.try_reserve(length as usize)?; @@ -134,11 +98,11 @@ where /// Converts a vector of columns associated with the parquet field whose name is [`Field`] /// to an iterator of [`Array`], [`ArrayIter`] of chunk size `chunk_size`. -pub fn to_deserializer<'a>( - columns: Vec<(&ColumnChunkMetaData, Vec)>, +pub fn to_deserializer( + columns: Vec<(&ColumnChunkMetadata, Vec)>, field: Field, filter: Option, -) -> PolarsResult> { +) -> PolarsResult> { let (columns, types): (Vec<_>, Vec<_>) = columns .into_iter() .map(|(column_meta, chunk)| { @@ -146,7 +110,6 @@ pub fn to_deserializer<'a>( let pages = PageReader::new( MemReader::from_vec(chunk), column_meta, - std::sync::Arc::new(|_, _| true), vec![], len * 2 + 1024, ); @@ -170,22 +133,22 @@ pub fn to_deserializer<'a>( /// This operation is single-threaded. For readers with stronger invariants /// (e.g. implement [`Clone`]) you can use [`read_columns`] to read multiple columns at once /// and convert them to [`ArrayIter`] via [`to_deserializer`]. 
-pub fn read_columns_many<'a, R: Read + Seek>( +pub fn read_columns_many( reader: &mut R, row_group: &RowGroupMetaData, - fields: Vec, + fields: &ArrowSchema, filter: Option, -) -> PolarsResult>> { +) -> PolarsResult>> { // reads all the necessary columns for all fields from the row group // This operation is IO-bounded `O(C)` where C is the number of columns in the row group let field_columns = fields - .iter() - .map(|field| read_columns(reader, row_group.columns(), &field.name)) + .iter_values() + .map(|field| read_columns(reader, row_group, &field.name)) .collect::>>()?; field_columns .into_iter() - .zip(fields) - .map(|(columns, field)| to_deserializer(columns, field, filter.clone())) + .zip(fields.iter_values().cloned()) + .map(|(columns, field)| to_deserializer(columns.clone(), field, filter.clone())) .collect() } diff --git a/crates/polars/tests/it/io/parquet/roundtrip.rs b/crates/polars/tests/it/io/parquet/roundtrip.rs index aa4eacb0e04d..6e105002fa53 100644 --- a/crates/polars/tests/it/io/parquet/roundtrip.rs +++ b/crates/polars/tests/it/io/parquet/roundtrip.rs @@ -10,14 +10,16 @@ use polars_parquet::write::{ CompressionOptions, Encoding, RowGroupIterator, StatisticsOptions, Version, }; +use crate::io::parquet::read::file::FileReader; + fn round_trip( array: &ArrayRef, version: Version, compression: CompressionOptions, encodings: Vec, ) -> PolarsResult<()> { - let field = Field::new("a1", array.data_type().clone(), true); - let schema = ArrowSchema::from(vec![field]); + let field = Field::new("a1".into(), array.dtype().clone(), true); + let schema = ArrowSchema::from_iter([field]); let options = WriteOptions { statistics: StatisticsOptions::full(), @@ -53,7 +55,7 @@ fn round_trip( .collect(); // we can then read the row groups into chunks - let chunks = polars_parquet::read::FileReader::new(reader, row_groups, schema, None); + let chunks = FileReader::new(reader, row_groups, schema, None); let mut arrays = vec![]; for chunk in chunks { diff --git a/crates/polars/tests/it/io/parquet/write/binary.rs b/crates/polars/tests/it/io/parquet/write/binary.rs index bb9abc62c258..8176a42cbf83 100644 --- a/crates/polars/tests/it/io/parquet/write/binary.rs +++ b/crates/polars/tests/it/io/parquet/write/binary.rs @@ -83,6 +83,6 @@ pub fn array_to_page_v1( DataPageHeader::V1(header), CowBuffer::Owned(buffer), descriptor.clone(), - Some(array.len()), + array.len(), ))) } diff --git a/crates/polars/tests/it/io/parquet/write/indexes.rs b/crates/polars/tests/it/io/parquet/write/indexes.rs deleted file mode 100644 index 3f5f15c92828..000000000000 --- a/crates/polars/tests/it/io/parquet/write/indexes.rs +++ /dev/null @@ -1,100 +0,0 @@ -use std::io::Cursor; - -use polars_parquet::parquet::compression::CompressionOptions; -use polars_parquet::parquet::error::ParquetResult; -use polars_parquet::parquet::indexes::{ - BoundaryOrder, Index, NativeIndex, PageIndex, PageLocation, -}; -use polars_parquet::parquet::metadata::SchemaDescriptor; -use polars_parquet::parquet::read::{read_columns_indexes, read_metadata, read_pages_locations}; -use polars_parquet::parquet::schema::types::{ParquetType, PhysicalType, PrimitiveType}; -use polars_parquet::parquet::write::{ - Compressor, DynIter, DynStreamingIterator, FileWriter, Version, WriteOptions, -}; - -use super::primitive::array_to_page_v1; - -fn write_file() -> ParquetResult> { - let page1 = vec![Some(0), Some(1), None, Some(3), Some(4), Some(5), Some(6)]; - let page2 = vec![Some(10), Some(11)]; - - let options = WriteOptions { - write_statistics: true, - version: 
Version::V1, - }; - - let schema = SchemaDescriptor::new( - "schema".to_string(), - vec![ParquetType::from_physical( - "col1".to_string(), - PhysicalType::Int32, - )], - ); - - let pages = vec![ - array_to_page_v1::(&page1, &options, &schema.columns()[0].descriptor), - array_to_page_v1::(&page2, &options, &schema.columns()[0].descriptor), - ]; - - let pages = DynStreamingIterator::new(Compressor::new( - DynIter::new(pages.into_iter()), - CompressionOptions::Uncompressed, - vec![], - )); - let columns = std::iter::once(Ok(pages)); - - let writer = Cursor::new(vec![]); - let mut writer = FileWriter::new(writer, schema, options, None); - - writer.write(DynIter::new(columns))?; - writer.end(None)?; - - Ok(writer.into_inner().into_inner()) -} - -#[test] -fn read_indexes_and_locations() -> ParquetResult<()> { - let data = write_file()?; - let mut reader = Cursor::new(data); - - let metadata = read_metadata(&mut reader)?; - - let columns = &metadata.row_groups[0].columns(); - - let expected_page_locations = vec![vec![ - PageLocation { - offset: 4, - compressed_page_size: 63, - first_row_index: 0, - }, - PageLocation { - offset: 67, - compressed_page_size: 47, - first_row_index: 7, - }, - ]]; - let expected_index = vec![Box::new(NativeIndex:: { - primitive_type: PrimitiveType::from_physical("col1".to_string(), PhysicalType::Int32), - indexes: vec![ - PageIndex { - min: Some(0), - max: Some(6), - null_count: Some(1), - }, - PageIndex { - min: Some(10), - max: Some(11), - null_count: Some(0), - }, - ], - boundary_order: BoundaryOrder::Unordered, - }) as Box]; - - let indexes = read_columns_indexes(&mut reader, columns)?; - assert_eq!(&indexes, &expected_index); - - let pages = read_pages_locations(&mut reader, columns)?; - assert_eq!(pages, expected_page_locations); - - Ok(()) -} diff --git a/crates/polars/tests/it/io/parquet/write/mod.rs b/crates/polars/tests/it/io/parquet/write/mod.rs index 7f066fe726e4..4403277a0552 100644 --- a/crates/polars/tests/it/io/parquet/write/mod.rs +++ b/crates/polars/tests/it/io/parquet/write/mod.rs @@ -1,5 +1,4 @@ mod binary; -mod indexes; mod primitive; mod sidecar; @@ -68,8 +67,8 @@ fn test_column(column: &str, compression: CompressionOptions) -> ParquetResult<( }; let schema = SchemaDescriptor::new( - "schema".to_string(), - vec![ParquetType::from_physical("col".to_string(), type_)], + "schema".into(), + vec![ParquetType::from_physical("col".into(), type_)], ); let a = schema.columns(); @@ -182,9 +181,9 @@ fn basic() -> ParquetResult<()> { }; let schema = SchemaDescriptor::new( - "schema".to_string(), + "schema".into(), vec![ParquetType::from_physical( - "col".to_string(), + "col".into(), PhysicalType::Int32, )], ); @@ -214,7 +213,11 @@ fn basic() -> ParquetResult<()> { // validated against an equivalent array produced by pyarrow. 
let expected = 51; assert_eq!( - metadata.row_groups[0].columns()[0].uncompressed_size(), + metadata.row_groups[0] + .columns_under_root_iter("col") + .next() + .unwrap() + .uncompressed_size(), expected ); diff --git a/crates/polars/tests/it/io/parquet/write/primitive.rs b/crates/polars/tests/it/io/parquet/write/primitive.rs index 044925c5bb11..210bf0e6cefb 100644 --- a/crates/polars/tests/it/io/parquet/write/primitive.rs +++ b/crates/polars/tests/it/io/parquet/write/primitive.rs @@ -74,6 +74,6 @@ pub fn array_to_page_v1( DataPageHeader::V1(header), CowBuffer::Owned(buffer), descriptor.clone(), - Some(array.len()), + array.len(), ))) } diff --git a/crates/polars/tests/it/io/parquet/write/sidecar.rs b/crates/polars/tests/it/io/parquet/write/sidecar.rs index 4df35d9e817d..00f4397ba6f4 100644 --- a/crates/polars/tests/it/io/parquet/write/sidecar.rs +++ b/crates/polars/tests/it/io/parquet/write/sidecar.rs @@ -6,11 +6,8 @@ use polars_parquet::parquet::write::{write_metadata_sidecar, FileWriter, Version #[test] fn basic() -> Result<(), ParquetError> { let schema = SchemaDescriptor::new( - "schema".to_string(), - vec![ParquetType::from_physical( - "c1".to_string(), - PhysicalType::Int32, - )], + "schema".into(), + vec![ParquetType::from_physical("c1".into(), PhysicalType::Int32)], ); let mut metadatas = vec![]; diff --git a/crates/polars/tests/it/joins.rs b/crates/polars/tests/it/joins.rs index 37ed6e2720d5..0fa0ba1c66a9 100644 --- a/crates/polars/tests/it/joins.rs +++ b/crates/polars/tests/it/joins.rs @@ -36,10 +36,14 @@ fn join_nans_outer() -> PolarsResult<()> { #[test] #[cfg(feature = "lazy")] fn join_empty_datasets() -> PolarsResult<()> { - let a = DataFrame::new(Vec::from([Series::new_empty("foo", &DataType::Int64)])).unwrap(); + let a = DataFrame::new(Vec::from([Series::new_empty( + "foo".into(), + &DataType::Int64, + )])) + .unwrap(); let b = DataFrame::new(Vec::from([ - Series::new_empty("foo", &DataType::Int64), - Series::new_empty("bar", &DataType::Int64), + Series::new_empty("foo".into(), &DataType::Int64), + Series::new_empty("bar".into(), &DataType::Int64), ])) .unwrap(); diff --git a/crates/polars/tests/it/lazy/aggregation.rs b/crates/polars/tests/it/lazy/aggregation.rs index 33662c442959..ad433e139775 100644 --- a/crates/polars/tests/it/lazy/aggregation.rs +++ b/crates/polars/tests/it/lazy/aggregation.rs @@ -4,7 +4,7 @@ use super::*; #[cfg(feature = "temporal")] fn test_lazy_agg() { let s0 = DateChunked::parse_from_str_slice( - "date", + "date".into(), &[ "2020-08-21", "2020-08-21", @@ -15,8 +15,8 @@ fn test_lazy_agg() { "%Y-%m-%d", ) .into_series(); - let s1 = Series::new("temp", [20, 10, 7, 9, 1].as_ref()); - let s2 = Series::new("rain", [0.2, 0.1, 0.3, 0.1, 0.01].as_ref()); + let s1 = Series::new("temp".into(), [20, 10, 7, 9, 1].as_ref()); + let s2 = Series::new("rain".into(), [0.2, 0.1, 0.3, 0.1, 0.01].as_ref()); let df = DataFrame::new(vec![s0, s1, s2]).unwrap(); let lf = df @@ -33,7 +33,7 @@ fn test_lazy_agg() { let new = lf.collect().unwrap(); let min = new.column("min").unwrap(); - assert_eq!(min, &Series::new("min", [0.1f64, 0.01, 0.1])); + assert_eq!(min, &Series::new("min".into(), [0.1f64, 0.01, 0.1])); } #[test] diff --git a/crates/polars/tests/it/lazy/cwc.rs b/crates/polars/tests/it/lazy/cwc.rs index ae836354982e..2ad0ab11ede4 100644 --- a/crates/polars/tests/it/lazy/cwc.rs +++ b/crates/polars/tests/it/lazy/cwc.rs @@ -76,7 +76,7 @@ fn fuzz_cluster_with_columns() { let column = rng.gen_range(0..unused_cols.len()); let column = unused_cols.swap_remove(column); - 
series.push(Series::new(to_str!(column), vec![rnd_prime(rng)])); + series.push(Series::new(to_str!(column).into(), vec![rnd_prime(rng)])); used_cols.push(column); } diff --git a/crates/polars/tests/it/lazy/expressions/apply.rs b/crates/polars/tests/it/lazy/expressions/apply.rs index d7814bc04c60..8006d7da8291 100644 --- a/crates/polars/tests/it/lazy/expressions/apply.rs +++ b/crates/polars/tests/it/lazy/expressions/apply.rs @@ -64,6 +64,7 @@ fn test_groups_update_binary_shift_log() -> PolarsResult<()> { } #[test] +#[cfg(feature = "cum_agg")] fn test_expand_list() -> PolarsResult<()> { let out = df![ "a" => [1, 2], diff --git a/crates/polars/tests/it/lazy/expressions/arity.rs b/crates/polars/tests/it/lazy/expressions/arity.rs index 9e0acb248acc..52ac97c56e62 100644 --- a/crates/polars/tests/it/lazy/expressions/arity.rs +++ b/crates/polars/tests/it/lazy/expressions/arity.rs @@ -197,7 +197,7 @@ fn test_update_groups_in_cast() -> PolarsResult<()> { let expected = df![ "group" => ["A" ,"B"], - "id"=> [AnyValue::List(Series::new("", [-2i64, -1])), AnyValue::List(Series::new("", [-2i64, -1, -1]))] + "id"=> [AnyValue::List(Series::new("".into(), [-2i64, -1])), AnyValue::List(Series::new("".into(), [-2i64, -1, -1]))] ]?; assert!(out.equals(&expected)); @@ -273,18 +273,18 @@ fn test_ternary_aggregation_set_literals() -> PolarsResult<()> { .group_by([col("name")]) .agg([when(col("value").sum().eq(lit(3))) .then(col("value").rank(Default::default(), None)) - .otherwise(lit(Series::new("", &[10 as IdxSize])))]) + .otherwise(lit(Series::new("".into(), &[10 as IdxSize])))]) .sort(["name"], Default::default()) .collect()?; let out = out.column("value")?; assert_eq!( out.get(0)?, - AnyValue::List(Series::new("", &[1 as IdxSize, 2 as IdxSize])) + AnyValue::List(Series::new("".into(), &[1 as IdxSize, 2 as IdxSize])) ); assert_eq!( out.get(1)?, - AnyValue::List(Series::new("", &[10 as IdxSize, 10 as IdxSize])) + AnyValue::List(Series::new("".into(), &[10 as IdxSize, 10 as IdxSize])) ); let out = df @@ -292,7 +292,7 @@ fn test_ternary_aggregation_set_literals() -> PolarsResult<()> { .lazy() .group_by([col("name")]) .agg([when(col("value").sum().eq(lit(3))) - .then(lit(Series::new("", &[10 as IdxSize])).alias("value")) + .then(lit(Series::new("".into(), &[10 as IdxSize])).alias("value")) .otherwise(col("value").rank(Default::default(), None))]) .sort(["name"], Default::default()) .collect()?; @@ -300,11 +300,11 @@ fn test_ternary_aggregation_set_literals() -> PolarsResult<()> { let out = out.column("value")?; assert_eq!( out.get(1)?, - AnyValue::List(Series::new("", &[1 as IdxSize, 2])) + AnyValue::List(Series::new("".into(), &[1 as IdxSize, 2])) ); assert_eq!( out.get(0)?, - AnyValue::List(Series::new("", &[10 as IdxSize, 10 as IdxSize])) + AnyValue::List(Series::new("".into(), &[10 as IdxSize, 10 as IdxSize])) ); let out = df diff --git a/crates/polars/tests/it/lazy/expressions/expand.rs b/crates/polars/tests/it/lazy/expressions/expand.rs index 27d8ee0ac1ad..69572ae0a454 100644 --- a/crates/polars/tests/it/lazy/expressions/expand.rs +++ b/crates/polars/tests/it/lazy/expressions/expand.rs @@ -13,7 +13,7 @@ fn test_expand_datetimes_3042() -> PolarsResult<()> { .and_hms_opt(0, 0, 0) .unwrap(); let date_range = polars_time::date_range( - "dt1", + "dt1".into(), low, high, Duration::parse("1w"), diff --git a/crates/polars/tests/it/lazy/expressions/is_in.rs b/crates/polars/tests/it/lazy/expressions/is_in.rs index e718b01ea032..73591af48328 100644 --- a/crates/polars/tests/it/lazy/expressions/is_in.rs +++ 
b/crates/polars/tests/it/lazy/expressions/is_in.rs @@ -6,7 +6,7 @@ fn test_is_in() -> PolarsResult<()> { "x" => [1, 2, 3], "y" => ["a", "b", "c"] ]?; - let s = Series::new("a", ["a", "b"]); + let s = Series::new("a".into(), ["a", "b"]); let out = df .lazy() diff --git a/crates/polars/tests/it/lazy/expressions/literals.rs b/crates/polars/tests/it/lazy/expressions/literals.rs new file mode 100644 index 000000000000..a2e1cf7822b4 --- /dev/null +++ b/crates/polars/tests/it/lazy/expressions/literals.rs @@ -0,0 +1,19 @@ +use super::*; + +#[test] +fn test_datetime_as_lit() { + let Expr::Alias(e, name) = datetime(Default::default()) else { + panic!() + }; + assert_eq!(name, "datetime"); + assert!(matches!(e.as_ref(), Expr::Literal(_))) +} + +#[test] +fn test_duration_as_lit() { + let Expr::Alias(e, name) = duration(Default::default()) else { + panic!() + }; + assert_eq!(name, "duration"); + assert!(matches!(e.as_ref(), Expr::Literal(_))) +} diff --git a/crates/polars/tests/it/lazy/expressions/mod.rs b/crates/polars/tests/it/lazy/expressions/mod.rs index 70be1828a28d..e52e550c5090 100644 --- a/crates/polars/tests/it/lazy/expressions/mod.rs +++ b/crates/polars/tests/it/lazy/expressions/mod.rs @@ -4,6 +4,7 @@ mod expand; mod filter; #[cfg(feature = "is_in")] mod is_in; +mod literals; mod slice; mod window; diff --git a/crates/polars/tests/it/lazy/expressions/window.rs b/crates/polars/tests/it/lazy/expressions/window.rs index 9865a3a54380..d617dd46574a 100644 --- a/crates/polars/tests/it/lazy/expressions/window.rs +++ b/crates/polars/tests/it/lazy/expressions/window.rs @@ -164,6 +164,7 @@ fn test_sort_by_in_groups() -> PolarsResult<()> { Ok(()) } #[test] +#[cfg(feature = "cum_agg")] fn test_literal_window_fn() -> PolarsResult<()> { let df = df![ "chars" => ["a", "a", "b"] @@ -216,7 +217,7 @@ fn test_window_mapping() -> PolarsResult<()> { .select([(lit(10) + col("A")).alias("foo").over([col("fruits")])]) .collect()?; - let expected = Series::new("foo", [11, 12, 13, 14, 15]); + let expected = Series::new("foo".into(), [11, 12, 13, 14, 15]); assert!(out.column("foo")?.equals(&expected)); let out = df @@ -231,7 +232,7 @@ fn test_window_mapping() -> PolarsResult<()> { .over([col("fruits")]), ]) .collect()?; - let expected = Series::new("foo", [11, 12, 8, 9, 15]); + let expected = Series::new("foo".into(), [11, 12, 8, 9, 15]); assert!(out.column("foo")?.equals(&expected)); let out = df @@ -246,7 +247,7 @@ fn test_window_mapping() -> PolarsResult<()> { .over([col("fruits")]), ]) .collect()?; - let expected = Series::new("foo", [None, Some(3), None, Some(-1), Some(-1)]); + let expected = Series::new("foo".into(), [None, Some(3), None, Some(-1), Some(-1)]); assert!(out.column("foo")?.equals_missing(&expected)); // now sorted @@ -258,7 +259,7 @@ fn test_window_mapping() -> PolarsResult<()> { .lazy() .select([(lit(10) + col("A")).alias("foo").over([col("fruits")])]) .collect()?; - let expected = Series::new("foo", [13, 14, 11, 12, 15]); + let expected = Series::new("foo".into(), [13, 14, 11, 12, 15]); assert!(out.column("foo")?.equals(&expected)); let out = df @@ -274,7 +275,7 @@ fn test_window_mapping() -> PolarsResult<()> { ]) .collect()?; - let expected = Series::new("foo", [8, 9, 11, 12, 15]); + let expected = Series::new("foo".into(), [8, 9, 11, 12, 15]); assert!(out.column("foo")?.equals(&expected)); let out = df @@ -289,7 +290,7 @@ fn test_window_mapping() -> PolarsResult<()> { ]) .collect()?; - let expected = Series::new("foo", [None, Some(-1), None, Some(3), Some(-1)]); + let expected = 
Series::new("foo".into(), [None, Some(-1), None, Some(3), Some(-1)]); assert!(out.column("foo")?.equals_missing(&expected)); Ok(()) diff --git a/crates/polars/tests/it/lazy/exprs.rs b/crates/polars/tests/it/lazy/exprs.rs index 66ccb4a7e444..45d550ae85a1 100644 --- a/crates/polars/tests/it/lazy/exprs.rs +++ b/crates/polars/tests/it/lazy/exprs.rs @@ -7,16 +7,19 @@ fn fuzz_exprs() { use rand::Rng; let lf = DataFrame::new(vec![ - Series::new("A", vec![1, 2, 3, 4, 5]), - Series::new("B", vec![Some(5), Some(4), None, Some(2), Some(1)]), - Series::new("C", vec!["str", "", "a quite long string", "my", "string"]), + Series::new("A".into(), vec![1, 2, 3, 4, 5]), + Series::new("B".into(), vec![Some(5), Some(4), None, Some(2), Some(1)]), + Series::new( + "C".into(), + vec!["str", "", "a quite long string", "my", "string"], + ), ]) .unwrap() .lazy(); let empty = DataFrame::new(vec![ - Series::new("A", Vec::::new()), - Series::new("B", Vec::::new()), - Series::new("C", Vec::<&str>::new()), + Series::new("A".into(), Vec::::new()), + Series::new("B".into(), Vec::::new()), + Series::new("C".into(), Vec::<&str>::new()), ]) .unwrap() .lazy(); diff --git a/crates/polars/tests/it/lazy/group_by.rs b/crates/polars/tests/it/lazy/group_by.rs index 1ccb481d6ee0..ac76e4921e40 100644 --- a/crates/polars/tests/it/lazy/group_by.rs +++ b/crates/polars/tests/it/lazy/group_by.rs @@ -77,7 +77,10 @@ fn test_filter_diff_arithmetic() -> PolarsResult<()> { .collect()?; let out = out.column("diff")?; - assert_eq!(out, &Series::new("diff", &[None, Some(26), Some(6), None])); + assert_eq!( + out, + &Series::new("diff".into(), &[None, Some(26), Some(6), None]) + ); Ok(()) } @@ -120,7 +123,7 @@ fn test_group_by_agg_list_with_not_aggregated() -> PolarsResult<()> { let out = out.explode()?; assert_eq!( out, - Series::new("value", &[0, 2, 1, 3, 2, 2, 7, 2, 3, 1, 2, 1]) + Series::new("value".into(), &[0, 2, 1, 3, 2, 2, 7, 2, 3, 1, 2, 1]) ); Ok(()) } @@ -174,7 +177,7 @@ fn test_filter_aggregated_expression() -> PolarsResult<()> { assert_eq!( x.get(1).unwrap(), - AnyValue::List(Series::new("", [0, 1, 2, 3, 4])) + AnyValue::List(Series::new("".into(), [0, 1, 2, 3, 4])) ); Ok(()) } diff --git a/crates/polars/tests/it/lazy/group_by_dynamic.rs b/crates/polars/tests/it/lazy/group_by_dynamic.rs index 6c65a4041ec8..4db863551faa 100644 --- a/crates/polars/tests/it/lazy/group_by_dynamic.rs +++ b/crates/polars/tests/it/lazy/group_by_dynamic.rs @@ -22,7 +22,7 @@ fn test_group_by_dynamic_week_bounds() -> PolarsResult<()> { .and_hms_opt(0, 0, 0) .unwrap(); let range = polars_time::date_range( - "dt", + "dt".into(), start, stop, Duration::parse("1d"), @@ -32,7 +32,7 @@ fn test_group_by_dynamic_week_bounds() -> PolarsResult<()> { )? 
.into_series(); - let a = Int32Chunked::full("a", 1, range.len()); + let a = Int32Chunked::full("a".into(), 1, range.len()); let df = df![ "dt" => range, "a" => a diff --git a/crates/polars/tests/it/lazy/predicate_queries.rs b/crates/polars/tests/it/lazy/predicate_queries.rs index 192c6150d7c0..ac180917fab3 100644 --- a/crates/polars/tests/it/lazy/predicate_queries.rs +++ b/crates/polars/tests/it/lazy/predicate_queries.rs @@ -1,5 +1,5 @@ // used only if feature="is_in", feature="dtype-categorical" -#[allow(unused_imports)] +#[cfg(all(feature = "is_in", feature = "dtype-categorical"))] use polars_core::{disable_string_cache, StringCacheHolder, SINGLE_LOCK}; use super::*; @@ -135,7 +135,7 @@ fn test_is_in_categorical_3420() -> PolarsResult<()> { disable_string_cache(); let _sc = StringCacheHolder::hold(); - let s = Series::new("x", ["a", "b", "c"]) + let s = Series::new("x".into(), ["a", "b", "c"]) .strict_cast(&DataType::Categorical(None, Default::default()))?; let out = df .lazy() diff --git a/crates/polars/tests/it/lazy/projection_queries.rs b/crates/polars/tests/it/lazy/projection_queries.rs index 496b13ab0aea..03b7a44bc114 100644 --- a/crates/polars/tests/it/lazy/projection_queries.rs +++ b/crates/polars/tests/it/lazy/projection_queries.rs @@ -106,6 +106,7 @@ fn test_many_aliasing_projections_5070() -> PolarsResult<()> { } #[test] +#[cfg(feature = "cum_agg")] fn test_projection_5086() -> PolarsResult<()> { let df = df![ "a" => ["a", "a", "a", "b"], @@ -146,8 +147,8 @@ fn test_projection_5086() -> PolarsResult<()> { #[cfg(feature = "dtype-struct")] fn test_unnest_pushdown() -> PolarsResult<()> { let df = df![ - "collection" => Series::full_null("", 1, &DataType::Int32), - "users" => Series::full_null("", 1, &DataType::List(Box::new(DataType::Struct(vec![Field::new("email", DataType::String)])))), + "collection" => Series::full_null("".into(), 1, &DataType::Int32), + "users" => Series::full_null("".into(), 1, &DataType::List(Box::new(DataType::Struct(vec![Field::new("email".into(), DataType::String)])))), ]?; let out = df diff --git a/crates/polars/tests/it/lazy/queries.rs b/crates/polars/tests/it/lazy/queries.rs index 8513efe0bc68..0be10b20f60e 100644 --- a/crates/polars/tests/it/lazy/queries.rs +++ b/crates/polars/tests/it/lazy/queries.rs @@ -4,7 +4,7 @@ use super::*; #[test] fn test_with_duplicate_column_empty_df() { - let a = Int32Chunked::from_slice("a", &[]); + let a = Int32Chunked::from_slice("a".into(), &[]); assert_eq!( DataFrame::new(vec![a.into_series()]) @@ -137,7 +137,7 @@ fn test_sorted_path() -> PolarsResult<()> { let payloads = &[1, 2, 3]; let df = df![ - "a"=> [AnyValue::List(Series::new("", payloads)), AnyValue::List(Series::new("", payloads)), AnyValue::List(Series::new("", payloads))] + "a"=> [AnyValue::List(Series::new("".into(), payloads)), AnyValue::List(Series::new("".into(), payloads)), AnyValue::List(Series::new("".into(), payloads))] ]?; let out = df @@ -234,11 +234,11 @@ fn test_apply_multiple_columns() -> PolarsResult<()> { #[test] fn test_group_by_on_lists() -> PolarsResult<()> { - let s0 = Series::new("", [1i32, 2, 3]); - let s1 = Series::new("groups", [4i32, 5]); + let s0 = Series::new("".into(), [1i32, 2, 3]); + let s1 = Series::new("groups".into(), [4i32, 5]); let mut builder = - ListPrimitiveChunkedBuilder::::new("arrays", 10, 10, DataType::Int32); + ListPrimitiveChunkedBuilder::::new("arrays".into(), 10, 10, DataType::Int32); builder.append_series(&s0).unwrap(); builder.append_series(&s1).unwrap(); let s2 = builder.finish().into_series(); diff --git 
a/crates/polars/tests/it/schema.rs b/crates/polars/tests/it/schema.rs index a636c06961b8..c791367f7546 100644 --- a/crates/polars/tests/it/schema.rs +++ b/crates/polars/tests/it/schema.rs @@ -8,9 +8,9 @@ fn test_schema_rename() { fn test_case(old: &str, new: &str, expected: Option<(&str, Vec)>) { fn make_schema() -> Schema { Schema::from_iter([ - Field::new("a", UInt64), - Field::new("b", Int32), - Field::new("c", Int8), + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), ]) } let mut schema = make_schema(); @@ -30,9 +30,9 @@ fn test_schema_rename() { Some(( "a", vec![ - Field::new("anton", UInt64), - Field::new("b", Int32), - Field::new("c", Int8), + Field::new("anton".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), ], )), ); @@ -43,9 +43,9 @@ fn test_schema_rename() { Some(( "b", vec![ - Field::new("a", UInt64), - Field::new("bantam", Int32), - Field::new("c", Int8), + Field::new("a".into(), UInt64), + Field::new("bantam".into(), Int32), + Field::new("c".into(), Int8), ], )), ); @@ -82,9 +82,9 @@ fn test_schema_insert_at_index() { } let schema = Schema::from_iter([ - Field::new("a", UInt64), - Field::new("b", Int32), - Field::new("c", Int8), + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), ]); test_case( @@ -94,10 +94,10 @@ fn test_schema_insert_at_index() { ( None, vec![ - Field::new("new", String), - Field::new("a", UInt64), - Field::new("b", Int32), - Field::new("c", Int8), + Field::new("new".into(), String), + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), ], ), ); @@ -109,9 +109,9 @@ fn test_schema_insert_at_index() { ( Some(UInt64), vec![ - Field::new("a", String), - Field::new("b", Int32), - Field::new("c", Int8), + Field::new("a".into(), String), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), ], ), ); @@ -123,9 +123,9 @@ fn test_schema_insert_at_index() { ( Some(Int32), vec![ - Field::new("b", String), - Field::new("a", UInt64), - Field::new("c", Int8), + Field::new("b".into(), String), + Field::new("a".into(), UInt64), + Field::new("c".into(), Int8), ], ), ); @@ -137,9 +137,9 @@ fn test_schema_insert_at_index() { ( Some(UInt64), vec![ - Field::new("b", Int32), - Field::new("a", String), - Field::new("c", Int8), + Field::new("b".into(), Int32), + Field::new("a".into(), String), + Field::new("c".into(), Int8), ], ), ); @@ -151,9 +151,9 @@ fn test_schema_insert_at_index() { ( Some(UInt64), vec![ - Field::new("b", Int32), - Field::new("c", Int8), - Field::new("a", String), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + Field::new("a".into(), String), ], ), ); @@ -165,9 +165,9 @@ fn test_schema_insert_at_index() { ( Some(UInt64), vec![ - Field::new("b", Int32), - Field::new("c", Int8), - Field::new("a", String), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + Field::new("a".into(), String), ], ), ); @@ -179,10 +179,10 @@ fn test_schema_insert_at_index() { ( None, vec![ - Field::new("a", UInt64), - Field::new("b", Int32), - Field::new("c", Int8), - Field::new("new", String), + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + Field::new("new".into(), String), ], ), ); @@ -194,9 +194,9 @@ fn test_schema_insert_at_index() { ( Some(Int8), vec![ - Field::new("a", UInt64), - Field::new("b", Int32), - Field::new("c", String), + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + 
Field::new("c".into(), String), ], ), ); @@ -208,9 +208,9 @@ fn test_schema_insert_at_index() { ( Some(Int8), vec![ - Field::new("a", UInt64), - Field::new("b", Int32), - Field::new("c", String), + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), String), ], ), ); @@ -239,9 +239,9 @@ fn test_with_column() { } let schema = Schema::from_iter([ - Field::new("a", UInt64), - Field::new("b", Int32), - Field::new("c", Int8), + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), ]); test_case( @@ -251,9 +251,9 @@ fn test_with_column() { ( Some(UInt64), vec![ - Field::new("a", String), - Field::new("b", Int32), - Field::new("c", Int8), + Field::new("a".into(), String), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), ], ), ); @@ -265,9 +265,9 @@ fn test_with_column() { ( Some(Int32), vec![ - Field::new("a", UInt64), - Field::new("b", String), - Field::new("c", Int8), + Field::new("a".into(), UInt64), + Field::new("b".into(), String), + Field::new("c".into(), Int8), ], ), ); @@ -279,9 +279,9 @@ fn test_with_column() { ( Some(Int8), vec![ - Field::new("a", UInt64), - Field::new("b", Int32), - Field::new("c", String), + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), String), ], ), ); @@ -293,10 +293,10 @@ fn test_with_column() { ( None, vec![ - Field::new("a", UInt64), - Field::new("b", Int32), - Field::new("c", Int8), - Field::new("d", String), + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + Field::new("d".into(), String), ], ), ); @@ -318,14 +318,14 @@ fn test_getters() { } let mut schema = Schema::from_iter([ - Field::new("a", UInt64), - Field::new("b", Int32), - Field::new("c", Int8), + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), ]); test_case!(schema, get, name: "a", &UInt64); test_case!(schema, get_full, name: "a", (0, &"a".into(), &UInt64)); - test_case!(schema, get_field, name: "a", Field::new("a", UInt64)); + test_case!(schema, get_field, name: "a", Field::new("a".into(), UInt64)); test_case!(schema, get_at_index, index: 1, (&"b".into(), &Int32)); test_case!(schema, get_at_index_mut, index: 1, (&mut "b".into(), &mut Int32)); @@ -366,10 +366,10 @@ fn test_removal() { } let schema = Schema::from_iter([ - Field::new("a", UInt64), - Field::new("b", Int32), - Field::new("c", Int8), - Field::new("d", Float64), + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + Field::new("d".into(), Float64), ]); test_case( @@ -377,14 +377,14 @@ fn test_removal() { "a", Some(UInt64), vec![ - Field::new("d", Float64), - Field::new("b", Int32), - Field::new("c", Int8), + Field::new("d".into(), Float64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), ], vec![ - Field::new("b", Int32), - Field::new("c", Int8), - Field::new("d", Float64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + Field::new("d".into(), Float64), ], ); @@ -393,14 +393,14 @@ fn test_removal() { "b", Some(Int32), vec![ - Field::new("a", UInt64), - Field::new("d", Float64), - Field::new("c", Int8), + Field::new("a".into(), UInt64), + Field::new("d".into(), Float64), + Field::new("c".into(), Int8), ], vec![ - Field::new("a", UInt64), - Field::new("c", Int8), - Field::new("d", Float64), + Field::new("a".into(), UInt64), + Field::new("c".into(), Int8), + Field::new("d".into(), Float64), ], ); @@ -409,14 +409,14 @@ fn 
test_removal() { "c", Some(Int8), vec![ - Field::new("a", UInt64), - Field::new("b", Int32), - Field::new("d", Float64), + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("d".into(), Float64), ], vec![ - Field::new("a", UInt64), - Field::new("b", Int32), - Field::new("d", Float64), + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("d".into(), Float64), ], ); @@ -425,14 +425,14 @@ fn test_removal() { "d", Some(Float64), vec![ - Field::new("a", UInt64), - Field::new("b", Int32), - Field::new("c", Int8), + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), ], vec![ - Field::new("a", UInt64), - Field::new("b", Int32), - Field::new("c", Int8), + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), ], ); @@ -441,16 +441,16 @@ fn test_removal() { "NOT_FOUND", None, vec![ - Field::new("a", UInt64), - Field::new("b", Int32), - Field::new("c", Int8), - Field::new("d", Float64), + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + Field::new("d".into(), Float64), ], vec![ - Field::new("a", UInt64), - Field::new("b", Int32), - Field::new("c", Int8), - Field::new("d", Float64), + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), + Field::new("d".into(), Float64), ], ); } @@ -486,9 +486,9 @@ fn test_set_dtype() { } let schema = Schema::from_iter([ - Field::new("a", UInt64), - Field::new("b", Int32), - Field::new("c", Int8), + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), ]); test_case( @@ -498,9 +498,9 @@ fn test_set_dtype() { ( Some(UInt64), vec![ - Field::new("a", String), - Field::new("b", Int32), - Field::new("c", Int8), + Field::new("a".into(), String), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), ], ), ); @@ -511,9 +511,9 @@ fn test_set_dtype() { ( Some(Int32), vec![ - Field::new("a", UInt64), - Field::new("b", String), - Field::new("c", Int8), + Field::new("a".into(), UInt64), + Field::new("b".into(), String), + Field::new("c".into(), Int8), ], ), ); @@ -524,9 +524,9 @@ fn test_set_dtype() { ( Some(Int8), vec![ - Field::new("a", UInt64), - Field::new("b", Int32), - Field::new("c", String), + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), String), ], ), ); @@ -537,9 +537,9 @@ fn test_set_dtype() { ( None, vec![ - Field::new("a", UInt64), - Field::new("b", Int32), - Field::new("c", Int8), + Field::new("a".into(), UInt64), + Field::new("b".into(), Int32), + Field::new("c".into(), Int8), ], ), ); diff --git a/crates/polars/tests/it/time/date_range.rs b/crates/polars/tests/it/time/date_range.rs index ff8df835cce2..f9ab68191a8d 100644 --- a/crates/polars/tests/it/time/date_range.rs +++ b/crates/polars/tests/it/time/date_range.rs @@ -14,7 +14,7 @@ fn test_time_units_9413() { .and_hms_opt(0, 0, 0) .unwrap(); let actual = date_range( - "date", + "date".into(), start, stop, Duration::parse("1d"), @@ -35,7 +35,7 @@ Series: 'date' [datetime[ms]] ])"#; assert_eq!(result, expected); let actual = date_range( - "date", + "date".into(), start, stop, Duration::parse("1d"), @@ -56,7 +56,7 @@ Series: 'date' [datetime[μs]] ])"#; assert_eq!(result, expected); let actual = date_range( - "date", + "date".into(), start, stop, Duration::parse("1d"), diff --git a/docs/development/contributing/code-style.md b/docs/development/contributing/code-style.md index 
b5ce5b2fcd9d..00ad8a8f726e 100644 --- a/docs/development/contributing/code-style.md +++ b/docs/development/contributing/code-style.md @@ -18,7 +18,7 @@ let ca: ChunkedArray = ... let arr: ArrayRef = ... let arr: PrimitiveArray = ... let dtype: DataType = ... -let data_type: ArrowDataType = ... +let dtype: ArrowDataType = ... ``` ### Code example @@ -66,7 +66,7 @@ fn compute_kernel2(arr_1: &PrimitiveArray, arr_2: &PrimitiveArray) -> P where T: Add + NativeType, { - binary(arr_1, arr_2, arr_1.data_type().clone(), |a, b| a + b) + binary(arr_1, arr_2, arr_1.dtype().clone(), |a, b| a + b) } fn compute_chunked_array_2_args( diff --git a/docs/requirements.txt b/docs/requirements.txt index db32e0f1cd39..186eb0e3fcec 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,10 +1,11 @@ +altair pandas pyarrow graphviz +hvplot matplotlib seaborn plotly -altair numba numpy @@ -13,4 +14,4 @@ mkdocs-macros-plugin==1.0.5 mkdocs-redirects==1.2.1 material-plausible-plugin==0.2.0 markdown-exec[ansi]==1.9.3 -pygithub==2.3.0 +pygithub==2.4.0 diff --git a/docs/src/python/user-guide/misc/visualization.py b/docs/src/python/user-guide/misc/visualization.py index f04288cb7812..cd256127f1dd 100644 --- a/docs/src/python/user-guide/misc/visualization.py +++ b/docs/src/python/user-guide/misc/visualization.py @@ -3,30 +3,33 @@ path = "docs/data/iris.csv" -df = pl.scan_csv(path).group_by("species").agg(pl.col("petal_length").mean()).collect() +df = pl.read_csv(path) print(df) # --8<-- [end:dataframe] """ # --8<-- [start:hvplot_show_plot] -df.plot.bar( - x="species", - y="petal_length", +import hvplot.polars +df.hvplot.scatter( + x="sepal_width", + y="sepal_length", + by="species", width=650, ) # --8<-- [end:hvplot_show_plot] """ # --8<-- [start:hvplot_make_plot] -import hvplot +import hvplot.polars -plot = df.plot.bar( - x="species", - y="petal_length", +plot = df.hvplot.scatter( + x="sepal_width", + y="sepal_length", + by="species", width=650, ) -hvplot.save(plot, "docs/images/hvplot_bar.html") -with open("docs/images/hvplot_bar.html", "r") as f: +hvplot.save(plot, "docs/images/hvplot_scatter.html") +with open("docs/images/hvplot_scatter.html", "r") as f: chart_html = f.read() print(f"{chart_html}") # --8<-- [end:hvplot_make_plot] @@ -35,7 +38,12 @@ # --8<-- [start:matplotlib_show_plot] import matplotlib.pyplot as plt -plt.bar(x=df["species"], height=df["petal_length"]) +fig, ax = plt.subplots() +ax.scatter( + x=df["sepal_width"], + y=df["sepal_length"], + c=df["species"].cast(pl.Categorical).to_physical(), +) # --8<-- [end:matplotlib_show_plot] """ @@ -44,9 +52,14 @@ import matplotlib.pyplot as plt -plt.bar(x=df["species"], height=df["petal_length"]) -plt.savefig("docs/images/matplotlib_bar.png") -with open("docs/images/matplotlib_bar.png", "rb") as f: +fig, ax = plt.subplots() +ax.scatter( + x=df["sepal_width"], + y=df["sepal_length"], + c=df["species"].cast(pl.Categorical).to_physical(), +) +fig.savefig("docs/images/matplotlib_scatter.png") +with open("docs/images/matplotlib_scatter.png", "rb") as f: png = base64.b64encode(f.read()).decode() print(f'') # --8<-- [end:matplotlib_make_plot] @@ -54,24 +67,28 @@ """ # --8<-- [start:seaborn_show_plot] import seaborn as sns -sns.barplot( +sns.scatterplot( df, - x="species", - y="petal_length", + x="sepal_width", + y="sepal_length", + hue="species", ) # --8<-- [end:seaborn_show_plot] """ # --8<-- [start:seaborn_make_plot] import seaborn as sns +import matplotlib.pyplot as plt -sns.barplot( +fig, ax = plt.subplots() +ax = sns.scatterplot( df, - x="species", - 
y="petal_length", + x="sepal_width", + y="sepal_length", + hue="species", ) -plt.savefig("docs/images/seaborn_bar.png") -with open("docs/images/seaborn_bar.png", "rb") as f: +fig.savefig("docs/images/seaborn_scatter.png") +with open("docs/images/seaborn_scatter.png", "rb") as f: png = base64.b64encode(f.read()).decode() print(f'') # --8<-- [end:seaborn_make_plot] @@ -80,11 +97,12 @@ # --8<-- [start:plotly_show_plot] import plotly.express as px -px.bar( +px.scatter( df, - x="species", - y="petal_length", - width=400, + x="sepal_width", + y="sepal_length", + color="species", + width=650, ) # --8<-- [end:plotly_show_plot] """ @@ -92,39 +110,47 @@ # --8<-- [start:plotly_make_plot] import plotly.express as px -fig = px.bar( +fig = px.scatter( df, - x="species", - y="petal_length", + x="sepal_width", + y="sepal_length", + color="species", width=650, ) -fig.write_html("docs/images/plotly_bar.html", full_html=False, include_plotlyjs="cdn") -with open("docs/images/plotly_bar.html", "r") as f: +fig.write_html( + "docs/images/plotly_scatter.html", full_html=False, include_plotlyjs="cdn" +) +with open("docs/images/plotly_scatter.html", "r") as f: chart_html = f.read() print(f"{chart_html}") # --8<-- [end:plotly_make_plot] """ # --8<-- [start:altair_show_plot] -import altair as alt - -alt.Chart(df, width=700).mark_bar().encode(x="species:N", y="petal_length:Q") +( + df.plot.point( + x="sepal_length", + y="sepal_width", + color="species", + ) + .properties(width=500) + .configure_scale(zero=False) +) # --8<-- [end:altair_show_plot] """ # --8<-- [start:altair_make_plot] -import altair as alt - chart = ( - alt.Chart(df, width=600) - .mark_bar() - .encode( - x="species:N", - y="petal_length:Q", + df.plot.point( + x="sepal_length", + y="sepal_width", + color="species", ) + .properties(width=500) + .configure_scale(zero=False) ) -chart.save("docs/images/altair_bar.html") -with open("docs/images/altair_bar.html", "r") as f: +chart.save("docs/images/altair_scatter.html") +with open("docs/images/altair_scatter.html", "r") as f: chart_html = f.read() print(f"{chart_html}") # --8<-- [end:altair_make_plot] diff --git a/docs/src/rust/Cargo.toml b/docs/src/rust/Cargo.toml index c1897560eb59..c99561340b12 100644 --- a/docs/src/rust/Cargo.toml +++ b/docs/src/rust/Cargo.toml @@ -134,6 +134,7 @@ required-features = ["polars/lazy", "polars/asof_join"] [[bin]] name = "user-guide-transformations-unpivot" path = "user-guide/transformations/unpivot.rs" +required-features = ["polars/pivot"] [[bin]] name = "user-guide-transformations-pivot" path = "user-guide/transformations/pivot.rs" diff --git a/docs/src/rust/user-guide/concepts/data-structures.rs b/docs/src/rust/user-guide/concepts/data-structures.rs index 2334f7718569..b8a4b70daa14 100644 --- a/docs/src/rust/user-guide/concepts/data-structures.rs +++ b/docs/src/rust/user-guide/concepts/data-structures.rs @@ -2,7 +2,7 @@ fn main() { // --8<-- [start:series] use polars::prelude::*; - let s = Series::new("a", &[1, 2, 3, 4, 5]); + let s = Series::new("a".into(), &[1, 2, 3, 4, 5]); println!("{}", s); // --8<-- [end:series] @@ -39,7 +39,7 @@ fn main() { // --8<-- [end:tail] // --8<-- [start:sample] - let n = Series::new("", &[2]); + let n = Series::new("".into(), &[2]); let sampled_df = df.sample_n(&n, false, false, None).unwrap(); println!("{}", sampled_df); diff --git a/docs/src/rust/user-guide/expressions/aggregation.rs b/docs/src/rust/user-guide/expressions/aggregation.rs index fe5e13a38940..9436565330bf 100644 --- a/docs/src/rust/user-guide/expressions/aggregation.rs +++ 
b/docs/src/rust/user-guide/expressions/aggregation.rs @@ -8,7 +8,7 @@ fn main() -> Result<(), Box> { let url = "https://theunitedstates.io/congress-legislators/legislators-historical.csv"; - let mut schema = Schema::new(); + let mut schema = Schema::default(); schema.with_column( "first_name".into(), DataType::Categorical(None, Default::default()), @@ -116,7 +116,7 @@ fn main() -> Result<(), Box> { compute_age() .filter(col("gender").eq(lit(gender))) .mean() - .alias(&format!("avg {} birthday", gender)) + .alias(format!("avg {} birthday", gender)) } let df = dataset diff --git a/docs/src/rust/user-guide/expressions/casting.rs b/docs/src/rust/user-guide/expressions/casting.rs index b18ca19022df..85824afc3198 100644 --- a/docs/src/rust/user-guide/expressions/casting.rs +++ b/docs/src/rust/user-guide/expressions/casting.rs @@ -135,7 +135,7 @@ fn main() -> Result<(), Box> { use chrono::prelude::*; let date = polars::time::date_range( - "date", + "date".into(), NaiveDate::from_ymd_opt(2022, 1, 1) .unwrap() .and_hms_opt(0, 0, 0) @@ -152,7 +152,7 @@ fn main() -> Result<(), Box> { .cast(&DataType::Date)?; let datetime = polars::time::date_range( - "datetime", + "datetime".into(), NaiveDate::from_ymd_opt(2022, 1, 1) .unwrap() .and_hms_opt(0, 0, 0) @@ -185,7 +185,7 @@ fn main() -> Result<(), Box> { // --8<-- [start:dates2] let date = polars::time::date_range( - "date", + "date".into(), NaiveDate::from_ymd_opt(2022, 1, 1) .unwrap() .and_hms_opt(0, 0, 0) diff --git a/docs/src/rust/user-guide/expressions/column-selections.rs b/docs/src/rust/user-guide/expressions/column-selections.rs index f3cacebd8c0c..c0f3f35ac3b0 100644 --- a/docs/src/rust/user-guide/expressions/column-selections.rs +++ b/docs/src/rust/user-guide/expressions/column-selections.rs @@ -9,14 +9,14 @@ fn main() -> Result<(), Box> { let df = df!( "id" => &[9, 4, 2], "place" => &["Mars", "Earth", "Saturn"], - "date" => date_range("date", + "date" => date_range("date".into(), NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), NaiveDate::from_ymd_opt(2022, 1, 3).unwrap().and_hms_opt(0, 0, 0).unwrap(), Duration::parse("1d"),ClosedWindow::Both, TimeUnit::Milliseconds, None)?, "sales" => &[33.4, 2142134.1, 44.7], "has_people" => &[false, true, false], - "logged_at" => date_range("logged_at", + "logged_at" => date_range("logged_at".into(), NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 2).unwrap(), Duration::parse("1s"),ClosedWindow::Both, TimeUnit::Milliseconds, None)?, )? 
- .with_row_index("index", None)?; + .with_row_index("index".into(), None)?; println!("{}", &df); // --8<-- [end:selectors_df] diff --git a/docs/src/rust/user-guide/expressions/lists.rs b/docs/src/rust/user-guide/expressions/lists.rs index 530ae4d79892..9ce160cd58aa 100644 --- a/docs/src/rust/user-guide/expressions/lists.rs +++ b/docs/src/rust/user-guide/expressions/lists.rs @@ -134,14 +134,17 @@ fn main() -> Result<(), Box> { // --8<-- [start:array_df] let mut col1: ListPrimitiveChunkedBuilder = - ListPrimitiveChunkedBuilder::new("Array_1", 8, 8, DataType::Int32); + ListPrimitiveChunkedBuilder::new("Array_1".into(), 8, 8, DataType::Int32); col1.append_slice(&[1, 3]); col1.append_slice(&[2, 5]); let mut col2: ListPrimitiveChunkedBuilder = - ListPrimitiveChunkedBuilder::new("Array_2", 8, 8, DataType::Int32); + ListPrimitiveChunkedBuilder::new("Array_2".into(), 8, 8, DataType::Int32); col2.append_slice(&[1, 7, 3]); col2.append_slice(&[8, 1, 0]); - let array_df = DataFrame::new([col1.finish(), col2.finish()].into())?; + let array_df = DataFrame::new(vec![ + col1.finish().into_series(), + col2.finish().into_series(), + ])?; println!("{}", &array_df); // --8<-- [end:array_df] diff --git a/docs/src/rust/user-guide/expressions/structs.rs b/docs/src/rust/user-guide/expressions/structs.rs index 0722b2aac5ee..25ed02daf827 100644 --- a/docs/src/rust/user-guide/expressions/structs.rs +++ b/docs/src/rust/user-guide/expressions/structs.rs @@ -17,7 +17,7 @@ fn main() -> Result<(), Box> { let out = ratings .clone() .lazy() - .select([col("Theatre").value_counts(true, true, "count".to_string(), false)]) + .select([col("Theatre").value_counts(true, true, "count", false)]) .collect()?; println!("{}", &out); // --8<-- [end:state_value_counts] @@ -26,7 +26,7 @@ fn main() -> Result<(), Box> { let out = ratings .clone() .lazy() - .select([col("Theatre").value_counts(true, true, "count".to_string(), false)]) + .select([col("Theatre").value_counts(true, true, "count", false)]) .unnest(["Theatre"]) .collect()?; println!("{}", &out); @@ -39,7 +39,7 @@ fn main() -> Result<(), Box> { "Theatre" => &["NE", "ME"], "Avg_Rating" => &[4.5, 4.9], )? 
- .into_struct("ratings") + .into_struct("ratings".into()) .into_series(); println!("{}", &rating_series); // // --8<-- [end:series_struct] @@ -54,7 +54,7 @@ fn main() -> Result<(), Box> { .lazy() .select([col("ratings") .struct_() - .rename_fields(["Film".into(), "State".into(), "Value".into()].to_vec())]) + .rename_fields(["Film", "State", "Value"].to_vec())]) .unnest(["ratings"]) .collect()?; diff --git a/docs/src/rust/user-guide/transformations/time-series/parsing.rs b/docs/src/rust/user-guide/transformations/time-series/parsing.rs index a58b5cf2850e..b12c488d0108 100644 --- a/docs/src/rust/user-guide/transformations/time-series/parsing.rs +++ b/docs/src/rust/user-guide/transformations/time-series/parsing.rs @@ -60,13 +60,13 @@ fn main() -> Result<(), Box> { Some(TimeUnit::Microseconds), None, StrptimeOptions { - format: Some("%Y-%m-%dT%H:%M:%S%z".to_string()), + format: Some("%Y-%m-%dT%H:%M:%S%z".into()), ..Default::default() }, lit("raise"), ) .dt() - .convert_time_zone("Europe/Brussels".to_string()); + .convert_time_zone("Europe/Brussels".into()); let mixed_parsed = df!("date" => &data)?.lazy().select([q]).collect()?; println!("{}", &mixed_parsed); diff --git a/docs/src/rust/user-guide/transformations/time-series/resampling.rs b/docs/src/rust/user-guide/transformations/time-series/resampling.rs index e1cd4baa1682..dec19f65fc26 100644 --- a/docs/src/rust/user-guide/transformations/time-series/resampling.rs +++ b/docs/src/rust/user-guide/transformations/time-series/resampling.rs @@ -6,7 +6,7 @@ use polars::prelude::*; fn main() -> Result<(), Box> { // --8<-- [start:df] let time = polars::time::date_range( - "time", + "time".into(), NaiveDate::from_ymd_opt(2021, 12, 16) .unwrap() .and_hms_opt(0, 0, 0) diff --git a/docs/src/rust/user-guide/transformations/time-series/rolling.rs b/docs/src/rust/user-guide/transformations/time-series/rolling.rs index 559bf0bc2fed..19b57f2d0c33 100644 --- a/docs/src/rust/user-guide/transformations/time-series/rolling.rs +++ b/docs/src/rust/user-guide/transformations/time-series/rolling.rs @@ -45,7 +45,7 @@ fn main() -> Result<(), Box> { // --8<-- [start:group_by_dyn] let time = polars::time::date_range( - "time", + "time".into(), NaiveDate::from_ymd_opt(2021, 1, 1) .unwrap() .and_hms_opt(0, 0, 0) @@ -106,7 +106,7 @@ fn main() -> Result<(), Box> { // --8<-- [start:group_by_roll] let time = polars::time::date_range( - "time", + "time".into(), NaiveDate::from_ymd_opt(2021, 12, 16) .unwrap() .and_hms_opt(0, 0, 0) diff --git a/docs/src/rust/user-guide/transformations/time-series/timezones.rs b/docs/src/rust/user-guide/transformations/time-series/timezones.rs index 4924338b4f86..489786cb844e 100644 --- a/docs/src/rust/user-guide/transformations/time-series/timezones.rs +++ b/docs/src/rust/user-guide/transformations/time-series/timezones.rs @@ -5,7 +5,7 @@ use polars::prelude::*; fn main() -> Result<(), Box> { // --8<-- [start:example] let ts = ["2021-03-27 03:00", "2021-03-28 03:00"]; - let tz_naive = Series::new("tz_naive", &ts); + let tz_naive = Series::new("tz_naive".into(), &ts); let time_zones_df = DataFrame::new(vec![tz_naive])? 
.lazy() .select([col("tz_naive").str().to_datetime( @@ -16,7 +16,7 @@ fn main() -> Result<(), Box> { )]) .with_columns([col("tz_naive") .dt() - .replace_time_zone(Some("UTC".to_string()), lit("raise"), NonExistent::Raise) + .replace_time_zone(Some("UTC".into()), lit("raise"), NonExistent::Raise) .alias("tz_aware")]) .collect()?; @@ -30,14 +30,14 @@ fn main() -> Result<(), Box> { col("tz_aware") .dt() .replace_time_zone( - Some("Europe/Brussels".to_string()), + Some("Europe/Brussels".into()), lit("raise"), NonExistent::Raise, ) .alias("replace time zone"), col("tz_aware") .dt() - .convert_time_zone("Asia/Kathmandu".to_string()) + .convert_time_zone("Asia/Kathmandu".into()) .alias("convert time zone"), col("tz_aware") .dt() diff --git a/docs/user-guide/expressions/plugins.md b/docs/user-guide/expressions/plugins.md index 9e79e5a2e3db..e679b09ee180 100644 --- a/docs/user-guide/expressions/plugins.md +++ b/docs/user-guide/expressions/plugins.md @@ -103,11 +103,12 @@ import polars as pl from polars.plugins import register_plugin_function from polars._typing import IntoExpr +PLUGIN_PATH = Path(__file__).parent def pig_latinnify(expr: IntoExpr) -> pl.Expr: """Pig-latinnify expression.""" return register_plugin_function( - plugin_path=Path(__file__).parent, + plugin_path=PLUGIN_PATH, function_name="pig_latinnify", args=expr, is_elementwise=True, @@ -190,7 +191,7 @@ def append_args( This example shows how arguments other than `Series` can be used. """ return register_plugin_function( - plugin_path=Path(__file__).parent, + plugin_path=PLUGIN_PATH, function_name="append_kwargs", args=expr, kwargs={ diff --git a/docs/user-guide/io/hugging-face.md b/docs/user-guide/io/hugging-face.md index 1a94210d657b..16f705ae75fb 100644 --- a/docs/user-guide/io/hugging-face.md +++ b/docs/user-guide/io/hugging-face.md @@ -65,7 +65,7 @@ See this file at [https://huggingface.co/datasets/nameexhaustion/polars-docs/blo #### Parquet -{{code_block('user-guide/io/hugging-face','scan_parquet_hive',['scan_parquet'])}} +{{code_block('user-guide/io/hugging-face','scan_parquet_hive_repr',['scan_parquet'])}} ```python exec="on" result="text" session="user-guide/io/hugging-face" --8<-- "python/user-guide/io/hugging-face.py:scan_parquet_hive_repr" diff --git a/docs/user-guide/misc/multiprocessing.md b/docs/user-guide/misc/multiprocessing.md index 4973da8c0155..d46a96a52bc5 100644 --- a/docs/user-guide/misc/multiprocessing.md +++ b/docs/user-guide/misc/multiprocessing.md @@ -52,7 +52,6 @@ Consider the example below, which is a slightly modified example posted on the [ {{code_block('user-guide/misc/multiprocess','example1',[])}} Using `fork` as the method, instead of `spawn`, will cause a dead lock. -Please note: Polars will not even start and raise the error on multiprocessing method being set wrong, but if the check had not been there, the deadlock would exist. The fork method is equivalent to calling `os.fork()`, which is a system call as defined in [the POSIX standard](https://pubs.opengroup.org/onlinepubs/9699919799/functions/fork.html): diff --git a/docs/user-guide/misc/visualization.md b/docs/user-guide/misc/visualization.md index 88dcd83a18a6..3f7574c07a2e 100644 --- a/docs/user-guide/misc/visualization.md +++ b/docs/user-guide/misc/visualization.md @@ -2,7 +2,8 @@ Data in a Polars `DataFrame` can be visualized using common visualization libraries. -We illustrate plotting capabilities using the Iris dataset. We scan a CSV and then do a group-by on the `species` column and get the mean of the `petal_length`. 
+We illustrate plotting capabilities using the Iris dataset. We read a CSV and then +plot one column against another, colored by yet another column. {{code_block('user-guide/misc/visualization','dataframe',[])}} @@ -10,9 +11,39 @@ We illustrate plotting capabilities using the Iris dataset. We scan a CSV and th --8<-- "python/user-guide/misc/visualization.py:dataframe" ``` -## Built-in plotting with hvPlot +## Built-in plotting with Altair -Polars has a `plot` method to create interactive plots using [hvPlot](https://hvplot.holoviz.org/). +Polars has a `plot` method to create plots using [Altair](https://altair-viz.github.io/): + +{{code_block('user-guide/misc/visualization','altair_show_plot',[])}} + +```python exec="on" session="user-guide/misc/visualization" +--8<-- "python/user-guide/misc/visualization.py:altair_make_plot" +``` + +This is shorthand for: + +```python +import altair as alt + +( + alt.Chart(df).mark_point().encode( + x="sepal_length", + y="sepal_width", + color="species", + ) + .properties(width=500) + .configure_scale(zero=False) +) +``` + +and is only provided for convenience, and to signal that Altair is known to work well with +Polars. + +## hvPlot + +If you import `hvplot.polars`, then it registers a `hvplot` +method which you can use to create interactive plots using [hvPlot](https://hvplot.holoviz.org/). {{code_block('user-guide/misc/visualization','hvplot_show_plot',[])}} @@ -22,8 +53,12 @@ Polars has a `plot` method to create interactive plots using [hvPlot](https://hv ## Matplotlib -To create a bar chart we can pass columns of a `DataFrame` directly to Matplotlib as a `Series` for each column. Matplotlib does not have explicit support for Polars objects but Matplotlib can accept a Polars `Series` because it can convert each Series to a numpy array, which is zero-copy for numeric -data without null values. +To create a scatter plot we can pass columns of a `DataFrame` directly to Matplotlib as a `Series` for each column. +Matplotlib does not have explicit support for Polars objects but can accept a Polars `Series` by +converting it to a NumPy array (which is zero-copy for numeric data without null values). + +Note that because the column `'species'` isn't numeric, we need to first convert it to numeric values so that +it can be passed as an argument to `c`. {{code_block('user-guide/misc/visualization','matplotlib_show_plot',[])}} @@ -31,9 +66,10 @@ data without null values. --8<-- "python/user-guide/misc/visualization.py:matplotlib_make_plot" ``` -## Seaborn, Plotly & Altair +## Seaborn and Plotly -[Seaborn](https://seaborn.pydata.org/), [Plotly](https://plotly.com/) & [Altair](https://altair-viz.github.io/) can accept a Polars `DataFrame` by leveraging the [dataframe interchange protocol](https://data-apis.org/dataframe-api/), which offers zero-copy conversion where possible. +[Seaborn](https://seaborn.pydata.org/) and [Plotly](https://plotly.com/) can accept a Polars `DataFrame` by leveraging the [dataframe interchange protocol](https://data-apis.org/dataframe-api/), which offers zero-copy conversion where possible. Note +that the protocol does not support all Polars data types (e.g. `List`) so your mileage may vary here. ### Seaborn {{code_block('user-guide/misc/visualization','seaborn_show_plot',[])}} @@ -50,11 +86,3 @@ data without null values. 
```python exec="on" session="user-guide/misc/visualization" --8<-- "python/user-guide/misc/visualization.py:plotly_make_plot" ``` - -### Altair - -{{code_block('user-guide/misc/visualization','altair_show_plot',[])}} - -```python exec="on" session="user-guide/misc/visualization" ---8<-- "python/user-guide/misc/visualization.py:altair_make_plot" -``` diff --git a/examples/datasets/tpc_heads/customer.feather b/examples/datasets/pds_heads/customer.feather similarity index 100% rename from examples/datasets/tpc_heads/customer.feather rename to examples/datasets/pds_heads/customer.feather diff --git a/examples/datasets/tpc_heads/lineitem.feather b/examples/datasets/pds_heads/lineitem.feather similarity index 100% rename from examples/datasets/tpc_heads/lineitem.feather rename to examples/datasets/pds_heads/lineitem.feather diff --git a/examples/datasets/tpc_heads/nation.feather b/examples/datasets/pds_heads/nation.feather similarity index 100% rename from examples/datasets/tpc_heads/nation.feather rename to examples/datasets/pds_heads/nation.feather diff --git a/examples/datasets/tpc_heads/orders.feather b/examples/datasets/pds_heads/orders.feather similarity index 100% rename from examples/datasets/tpc_heads/orders.feather rename to examples/datasets/pds_heads/orders.feather diff --git a/examples/datasets/tpc_heads/part.feather b/examples/datasets/pds_heads/part.feather similarity index 100% rename from examples/datasets/tpc_heads/part.feather rename to examples/datasets/pds_heads/part.feather diff --git a/examples/datasets/tpc_heads/partsupp.feather b/examples/datasets/pds_heads/partsupp.feather similarity index 100% rename from examples/datasets/tpc_heads/partsupp.feather rename to examples/datasets/pds_heads/partsupp.feather diff --git a/examples/datasets/tpc_heads/region.feather b/examples/datasets/pds_heads/region.feather similarity index 100% rename from examples/datasets/tpc_heads/region.feather rename to examples/datasets/pds_heads/region.feather diff --git a/examples/datasets/tpc_heads/supplier.feather b/examples/datasets/pds_heads/supplier.feather similarity index 100% rename from examples/datasets/tpc_heads/supplier.feather rename to examples/datasets/pds_heads/supplier.feather diff --git a/examples/python_rust_compiled_function/src/ffi.rs b/examples/python_rust_compiled_function/src/ffi.rs index 22222e8e20f8..3597d1f83a03 100644 --- a/examples/python_rust_compiled_function/src/ffi.rs +++ b/examples/python_rust_compiled_function/src/ffi.rs @@ -24,7 +24,7 @@ fn array_to_rust(arrow_array: &Bound) -> PyResult { unsafe { let field = ffi::import_field_from_c(schema.as_ref()).unwrap(); - let array = ffi::import_array_from_c(*array, field.data_type).unwrap(); + let array = ffi::import_array_from_c(*array, field.dtype).unwrap(); Ok(array) } } @@ -33,7 +33,7 @@ fn array_to_rust(arrow_array: &Bound) -> PyResult { pub(crate) fn to_py_array(py: Python, pyarrow: &Bound, array: ArrayRef) -> PyResult { let schema = Box::new(ffi::export_field_to_c(&ArrowField::new( "", - array.data_type().clone(), + array.dtype().clone(), true, ))); let array = Box::new(ffi::export_array_to_c(array)); diff --git a/examples/read_csv/src/main.rs b/examples/read_csv/src/main.rs index aa9188f19409..877fc6483635 100644 --- a/examples/read_csv/src/main.rs +++ b/examples/read_csv/src/main.rs @@ -2,7 +2,7 @@ use polars::io::mmap::MmapBytesReader; use polars::prelude::*; fn main() -> PolarsResult<()> { - let file = std::fs::File::open("/home/ritchie46/Downloads/tpch/tables_scale_100/lineitem.tbl") + let file = 
std::fs::File::open("/home/ritchie46/Downloads/pdsh/tables_scale_100/lineitem.tbl") .unwrap(); let file = Box::new(file) as Box; let _df = CsvReader::new(file) diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index 7998cf714b81..a42c643516ea 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "py-polars" -version = "1.4.1" +version = "1.6.0" edition = "2021" [lib] @@ -8,104 +8,9 @@ name = "polars" crate-type = ["cdylib"] [dependencies] -polars-core = { workspace = true, features = ["python"] } -polars-error = { workspace = true } -polars-io = { workspace = true } -polars-lazy = { workspace = true, features = ["python"] } -polars-ops = { workspace = true } -polars-parquet = { workspace = true, optional = true } -polars-plan = { workspace = true } -polars-time = { workspace = true } -polars-utils = { workspace = true } - -# TODO! remove this once truly activated. This is required to make sdist building work -polars-stream = { workspace = true } - -ahash = { workspace = true } -arboard = { workspace = true, optional = true } -bytemuck = { workspace = true } -ciborium = { workspace = true } -either = { workspace = true } -itoa = { workspace = true } -libc = "0.2" -ndarray = { workspace = true } -num-traits = { workspace = true } -# TODO: Pin to released version once NumPy 2.0 support is merged -# https://github.com/PyO3/rust-numpy/issues/409 -numpy = { git = "https://github.com/stinodego/rust-numpy.git", rev = "9ba9962ae57ba26e35babdce6f179edf5fe5b9c8", default-features = false } -once_cell = { workspace = true } +libc = { workspace = true } +polars-python = { workspace = true, features = ["pymethods"] } pyo3 = { workspace = true, features = ["abi3-py38", "chrono", "extension-module", "multiple-pymethods"] } -recursive = { workspace = true } -serde_json = { workspace = true, optional = true } -smartstring = { workspace = true } -thiserror = { workspace = true } - -[dependencies.polars] -workspace = true -features = [ - "abs", - "approx_unique", - "array_any_all", - "arg_where", - "business", - "concat_str", - "cum_agg", - "cumulative_eval", - "dataframe_arithmetic", - "month_start", - "month_end", - "offset_by", - "diagonal_concat", - "diff", - "dot_diagram", - "dot_product", - "dtype-categorical", - "dtype-full", - "dynamic_group_by", - "ewma", - "ewma_by", - "fmt", - "fused", - "interpolate", - "interpolate_by", - "is_first_distinct", - "is_last_distinct", - "is_unique", - "is_between", - "lazy", - "list_eval", - "list_to_struct", - "array_to_struct", - "log", - "mode", - "moment", - "ndarray", - "partition_by", - "product", - "random", - "range", - "rank", - "reinterpret", - "replace", - "rolling_window", - "rolling_window_by", - "round_series", - "row_hash", - "rows", - "semi_anti_join", - "serde-lazy", - "string_encoding", - "string_reverse", - "string_to_integer", - "string_pad", - "strings", - "temporal", - "to_dummies", - "true_div", - "unique_counts", - "zip_with", - "cov", -] [build-dependencies] built = { version = "0.7", features = ["chrono", "git2", "cargo-lock"], optional = true } @@ -113,149 +18,95 @@ built = { version = "0.7", features = ["chrono", "git2", "cargo-lock"], optional [target.'cfg(all(any(not(target_family = "unix"), allocator = "mimalloc"), not(allocator = "default")))'.dependencies] mimalloc = { version = "0.1", default-features = false } -[target.'cfg(all(target_family = "unix", not(allocator = "mimalloc"), not(allocator = "default")))'.dependencies] +# Feature background_threads is unsupported on MacOS 
(https://github.com/jemalloc/jemalloc/issues/843). +[target.'cfg(all(target_family = "unix", not(target_os = "macos"), not(allocator = "mimalloc"), not(allocator = "default")))'.dependencies] +jemallocator = { version = "0.5", features = ["disable_initial_exec_tls", "background_threads"] } + +[target.'cfg(all(target_family = "unix", target_os = "macos", not(allocator = "mimalloc"), not(allocator = "default")))'.dependencies] jemallocator = { version = "0.5", features = ["disable_initial_exec_tls"] } [features] -# Features below are only there to enable building a slim binary during development. -avro = ["polars/avro"] -parquet = ["polars/parquet", "polars-parquet"] -ipc = ["polars/ipc"] -ipc_streaming = ["polars/ipc_streaming"] -is_in = ["polars/is_in"] -json = ["polars/serde", "serde_json", "polars/json"] -trigonometry = ["polars/trigonometry"] -sign = ["polars/sign"] -asof_join = ["polars/asof_join"] -cross_join = ["polars/cross_join"] -pct_change = ["polars/pct_change"] -repeat_by = ["polars/repeat_by"] -# also includes simd -nightly = ["polars/nightly"] -streaming = ["polars/streaming"] -meta = ["polars/meta"] -search_sorted = ["polars/search_sorted"] -decompress = ["polars/decompress-fast"] -regex = ["polars/regex"] -csv = ["polars/csv"] -clipboard = ["arboard"] -extract_jsonpath = ["polars/extract_jsonpath"] -pivot = ["polars/pivot"] -top_k = ["polars/top_k"] -propagate_nans = ["polars/propagate_nans"] -sql = ["polars/sql"] +# Features used in this crate build_info = ["dep:built"] -performant = ["polars/performant"] -timezones = ["polars/timezones"] -cse = ["polars/cse"] -merge_sorted = ["polars/merge_sorted"] -list_gather = ["polars/list_gather"] -list_count = ["polars/list_count"] -array_count = ["polars/array_count", "polars/dtype-array"] -binary_encoding = ["polars/binary_encoding"] -list_sets = ["polars-lazy/list_sets"] -list_any_all = ["polars/list_any_all"] -array_any_all = ["polars/array_any_all", "polars/dtype-array"] -list_drop_nulls = ["polars/list_drop_nulls"] -list_sample = ["polars/list_sample"] -cutqcut = ["polars/cutqcut"] -rle = ["polars/rle"] -extract_groups = ["polars/extract_groups"] -ffi_plugin = ["polars-plan/ffi_plugin"] -cloud = ["polars/cloud", "polars/aws", "polars/gcp", "polars/azure", "polars/http"] -peaks = ["polars/peaks"] -hist = ["polars/hist"] -find_many = ["polars/find_many"] -new_streaming = ["polars-lazy/new_streaming"] - -dtype-i8 = [] -dtype-i16 = [] -dtype-u8 = [] -dtype-u16 = [] -dtype-array = [] -object = ["polars/object"] - -dtypes = [ - "dtype-array", - "dtype-i16", - "dtype-i8", - "dtype-u16", - "dtype-u8", - "object", -] - -operations = [ - "array_any_all", - "array_count", - "is_in", - "repeat_by", - "trigonometry", - "sign", - "performant", - "list_gather", - "list_count", - "list_sets", - "list_any_all", - "list_drop_nulls", - "list_sample", - "cutqcut", - "rle", - "extract_groups", - "pivot", - "extract_jsonpath", - "asof_join", - "cross_join", - "pct_change", - "search_sorted", - "merge_sorted", - "top_k", - "propagate_nans", - "timezones", - "peaks", - "hist", - "find_many", -] - -io = [ - "json", - "parquet", - "ipc", - "ipc_streaming", - "avro", - "csv", - "cloud", - "clipboard", -] - -optimizations = [ - "cse", - "polars/fused", - "streaming", -] - -polars_cloud = ["polars/polars_cloud"] +ffi_plugin = ["polars-python/ffi_plugin"] +csv = ["polars-python/csv"] +polars_cloud = ["polars-python/polars_cloud"] +object = ["polars-python/object"] +clipboard = ["polars-python/clipboard"] +sql = ["polars-python/sql"] +trigonometry = 
["polars-python/trigonometry"] +parquet = ["polars-python/parquet"] +ipc = ["polars-python/ipc"] + +# Features passed through to the polars-python crate +avro = ["polars-python/avro"] +ipc_streaming = ["polars-python/ipc_streaming"] +is_in = ["polars-python/is_in"] +json = ["polars-python/json"] +sign = ["polars-python/sign"] +asof_join = ["polars-python/asof_join"] +cross_join = ["polars-python/cross_join"] +pct_change = ["polars-python/pct_change"] +repeat_by = ["polars-python/repeat_by"] +# also includes simd +nightly = ["polars-python/nightly"] +streaming = ["polars-python/streaming"] +meta = ["polars-python/meta"] +search_sorted = ["polars-python/search_sorted"] +decompress = ["polars-python/decompress"] +regex = ["polars-python/regex"] +extract_jsonpath = ["polars-python/extract_jsonpath"] +pivot = ["polars-python/pivot"] +top_k = ["polars-python/top_k"] +propagate_nans = ["polars-python/propagate_nans"] +performant = ["polars-python/performant"] +timezones = ["polars-python/timezones"] +cse = ["polars-python/cse"] +merge_sorted = ["polars-python/merge_sorted"] +list_gather = ["polars-python/list_gather"] +list_count = ["polars-python/list_count"] +array_count = ["polars-python/array_count"] +binary_encoding = ["polars-python/binary_encoding"] +list_sets = ["polars-python/list_sets"] +list_any_all = ["polars-python/list_any_all"] +array_any_all = ["polars-python/array_any_all"] +list_drop_nulls = ["polars-python/list_drop_nulls"] +list_sample = ["polars-python/list_sample"] +cutqcut = ["polars-python/cutqcut"] +rle = ["polars-python/rle"] +extract_groups = ["polars-python/extract_groups"] +cloud = ["polars-python/cloud"] +peaks = ["polars-python/peaks"] +hist = ["polars-python/hist"] +find_many = ["polars-python/find_many"] +new_streaming = ["polars-python/new_streaming"] + +dtype-i8 = ["polars-python/dtype-i8"] +dtype-i16 = ["polars-python/dtype-i16"] +dtype-u8 = ["polars-python/dtype-u8"] +dtype-u16 = ["polars-python/dtype-u16"] +dtype-array = ["polars-python/dtype-array"] + +dtypes = ["polars-python/dtypes"] + +operations = ["polars-python/operations"] + +io = ["polars-python/io"] + +optimizations = ["polars-python/optimizations"] all = [ - "optimizations", - "io", - "operations", - "dtypes", - "meta", - "decompress", - "regex", "build_info", - "sql", - "binary_encoding", "ffi_plugin", + "csv", "polars_cloud", - # "new_streaming", + "object", + "clipboard", + "sql", + "trigonometry", + "parquet", + "ipc", + "polars-python/all", ] -# we cannot conditionally activate simd -# https://github.com/rust-lang/cargo/issues/1197 -# so we have an indirection and compile -# with --no-default-features --features=all for targets without simd -default = [ - "all", - "nightly", -] +default = ["all", "nightly"] diff --git a/py-polars/Makefile b/py-polars/Makefile index 7e273b14914c..3c98adab08cb 100644 --- a/py-polars/Makefile +++ b/py-polars/Makefile @@ -113,7 +113,7 @@ clean: ## Clean up caches and build artifacts @rm -rf .mypy_cache/ @rm -rf .pytest_cache/ @$(VENV_BIN)/ruff clean - @rm -rf tests/data/tpch/sf* + @rm -rf tests/data/pdsh/sf* @rm -f .coverage @rm -f coverage.xml @rm -f polars/polars.abi3.so diff --git a/py-polars/docs/source/reference/dataframe/export.rst b/py-polars/docs/source/reference/dataframe/export.rst index 0347b7429da0..8ebb005221eb 100644 --- a/py-polars/docs/source/reference/dataframe/export.rst +++ b/py-polars/docs/source/reference/dataframe/export.rst @@ -8,6 +8,8 @@ Export DataFrame data to other formats: .. 
autosummary:: :toctree: api/ + DataFrame.__array__ + DataFrame.__arrow_c_stream__ DataFrame.__dataframe__ DataFrame.to_arrow DataFrame.to_dict diff --git a/py-polars/docs/source/reference/dataframe/modify_select.rst b/py-polars/docs/source/reference/dataframe/modify_select.rst index 11042e70c7bd..b3a3d024ebd2 100644 --- a/py-polars/docs/source/reference/dataframe/modify_select.rst +++ b/py-polars/docs/source/reference/dataframe/modify_select.rst @@ -6,6 +6,7 @@ Manipulation/selection .. autosummary:: :toctree: api/ + DataFrame.__getitem__ DataFrame.bottom_k DataFrame.cast DataFrame.clear @@ -34,6 +35,7 @@ Manipulation/selection DataFrame.iter_slices DataFrame.join DataFrame.join_asof + DataFrame.join_where DataFrame.limit DataFrame.melt DataFrame.merge_sorted diff --git a/py-polars/docs/source/reference/expressions/aggregation.rst b/py-polars/docs/source/reference/expressions/aggregation.rst index d57b76618b31..1162f9a25ac3 100644 --- a/py-polars/docs/source/reference/expressions/aggregation.rst +++ b/py-polars/docs/source/reference/expressions/aggregation.rst @@ -7,6 +7,9 @@ Aggregation :toctree: api/ Expr.agg_groups + Expr.all + Expr.any + Expr.approx_n_unique Expr.arg_max Expr.arg_min Expr.count @@ -18,8 +21,10 @@ Aggregation Expr.mean Expr.median Expr.min + Expr.n_unique Expr.nan_max Expr.nan_min + Expr.null_count Expr.product Expr.quantile Expr.std diff --git a/py-polars/docs/source/reference/expressions/array.rst b/py-polars/docs/source/reference/expressions/array.rst index f25f2a30bbfd..1573c478920c 100644 --- a/py-polars/docs/source/reference/expressions/array.rst +++ b/py-polars/docs/source/reference/expressions/array.rst @@ -9,27 +9,27 @@ The following methods are available under the `expr.arr` attribute. :toctree: api/ :template: autosummary/accessor_method.rst - Expr.arr.max - Expr.arr.min - Expr.arr.median - Expr.arr.sum - Expr.arr.std - Expr.arr.to_list - Expr.arr.unique - Expr.arr.n_unique - Expr.arr.var Expr.arr.all Expr.arr.any - Expr.arr.sort - Expr.arr.reverse - Expr.arr.arg_min Expr.arr.arg_max - Expr.arr.get - Expr.arr.first - Expr.arr.last - Expr.arr.join - Expr.arr.explode + Expr.arr.arg_min Expr.arr.contains Expr.arr.count_matches - Expr.arr.to_struct + Expr.arr.explode + Expr.arr.first + Expr.arr.get + Expr.arr.join + Expr.arr.last + Expr.arr.max + Expr.arr.median + Expr.arr.min + Expr.arr.n_unique + Expr.arr.reverse Expr.arr.shift + Expr.arr.sort + Expr.arr.std + Expr.arr.sum + Expr.arr.to_list + Expr.arr.to_struct + Expr.arr.unique + Expr.arr.var diff --git a/py-polars/docs/source/reference/expressions/col.rst b/py-polars/docs/source/reference/expressions/col.rst index 09b5c33e82f7..612e56e4cd63 100644 --- a/py-polars/docs/source/reference/expressions/col.rst +++ b/py-polars/docs/source/reference/expressions/col.rst @@ -2,7 +2,7 @@ polars.col ========== -Create an expression representing column(s) in a dataframe. +Create an expression representing column(s) in a DataFrame. ``col`` is technically not a function, but it can be used like one. 
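As a quick illustration of the `col` behaviour described in the doc line above, a minimal usage sketch (standard public Polars API only):

import polars as pl

df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

# `col` can be called like a function, with one or more column names...
df.select(pl.col("a") * 2)
df.select(pl.col("a", "b").sum())

# ...and, being an object rather than a plain function, it also supports
# attribute access as shorthand for a single column.
df.select(pl.col.a + pl.col.b)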
diff --git a/py-polars/docs/source/reference/expressions/computation.rst b/py-polars/docs/source/reference/expressions/computation.rst index 6f90e9c1eab2..46dba474834f 100644 --- a/py-polars/docs/source/reference/expressions/computation.rst +++ b/py-polars/docs/source/reference/expressions/computation.rst @@ -42,7 +42,6 @@ Computation Expr.log1p Expr.mode Expr.n_unique - Expr.null_count Expr.pct_change Expr.peak_max Expr.peak_min diff --git a/py-polars/docs/source/reference/expressions/functions.rst b/py-polars/docs/source/reference/expressions/functions.rst index 4a8ca0425fca..b27fa4e87f84 100644 --- a/py-polars/docs/source/reference/expressions/functions.rst +++ b/py-polars/docs/source/reference/expressions/functions.rst @@ -35,9 +35,9 @@ These functions are available from the Polars module root and can be used as exp cum_sum cum_sum_horizontal date - datetime date_range date_ranges + datetime datetime_range datetime_ranges duration @@ -73,12 +73,12 @@ These functions are available from the Polars module root and can be used as exp rolling_corr rolling_cov select + sql + sql_expr std struct sum sum_horizontal - sql - sql_expr tail time time_range @@ -97,7 +97,6 @@ These functions are available from the Polars module root and can be used as exp Expr.any Expr.approx_n_unique Expr.count - Expr.exclude Expr.first Expr.head Expr.implode diff --git a/py-polars/docs/source/reference/expressions/list.rst b/py-polars/docs/source/reference/expressions/list.rst index 7d330772511d..18b9dd9c4867 100644 --- a/py-polars/docs/source/reference/expressions/list.rst +++ b/py-polars/docs/source/reference/expressions/list.rst @@ -11,17 +11,18 @@ The following methods are available under the `expr.list` attribute. Expr.list.all Expr.list.any - Expr.list.drop_nulls Expr.list.arg_max Expr.list.arg_min Expr.list.concat Expr.list.contains Expr.list.count_matches Expr.list.diff + Expr.list.drop_nulls Expr.list.eval Expr.list.explode Expr.list.first Expr.list.gather + Expr.list.gather_every Expr.list.get Expr.list.head Expr.list.join @@ -31,6 +32,7 @@ The following methods are available under the `expr.list` attribute. Expr.list.mean Expr.list.median Expr.list.min + Expr.list.n_unique Expr.list.reverse Expr.list.sample Expr.list.set_difference @@ -46,6 +48,4 @@ The following methods are available under the `expr.list` attribute. Expr.list.to_array Expr.list.to_struct Expr.list.unique - Expr.list.n_unique Expr.list.var - Expr.list.gather_every diff --git a/py-polars/docs/source/reference/expressions/name.rst b/py-polars/docs/source/reference/expressions/name.rst index c687651d6278..80693496c350 100644 --- a/py-polars/docs/source/reference/expressions/name.rst +++ b/py-polars/docs/source/reference/expressions/name.rst @@ -11,10 +11,10 @@ The following methods are available under the `expr.name` attribute. 
Expr.name.keep Expr.name.map + Expr.name.map_fields Expr.name.prefix + Expr.name.prefix_fields Expr.name.suffix + Expr.name.suffix_fields Expr.name.to_lowercase Expr.name.to_uppercase - Expr.name.map_fields - Expr.name.prefix_fields - Expr.name.suffix_fields diff --git a/py-polars/docs/source/reference/expressions/operators.rst b/py-polars/docs/source/reference/expressions/operators.rst index 397a0998a4a4..c4ce55a2b144 100644 --- a/py-polars/docs/source/reference/expressions/operators.rst +++ b/py-polars/docs/source/reference/expressions/operators.rst @@ -42,9 +42,9 @@ Numeric Expr.mod Expr.mul Expr.neg + Expr.pow Expr.sub Expr.truediv - Expr.pow Binary diff --git a/py-polars/docs/source/reference/expressions/string.rst b/py-polars/docs/source/reference/expressions/string.rst index 659cd0dbe40c..a0cde717f0da 100644 --- a/py-polars/docs/source/reference/expressions/string.rst +++ b/py-polars/docs/source/reference/expressions/string.rst @@ -51,7 +51,7 @@ The following methods are available under the `expr.str` attribute. Expr.str.to_decimal Expr.str.to_integer Expr.str.to_lowercase - Expr.str.to_titlecase Expr.str.to_time + Expr.str.to_titlecase Expr.str.to_uppercase Expr.str.zfill diff --git a/py-polars/docs/source/reference/lazyframe/index.rst b/py-polars/docs/source/reference/lazyframe/index.rst index 889437cb8fb9..fe0ea51e60ee 100644 --- a/py-polars/docs/source/reference/lazyframe/index.rst +++ b/py-polars/docs/source/reference/lazyframe/index.rst @@ -11,9 +11,9 @@ This page gives an overview of all public LazyFrame methods. aggregation attributes descriptive + group_by modify_select miscellaneous - group_by in_process .. _lazyframe: diff --git a/py-polars/docs/source/reference/lazyframe/modify_select.rst b/py-polars/docs/source/reference/lazyframe/modify_select.rst index 925591ed8649..f26a600966d2 100644 --- a/py-polars/docs/source/reference/lazyframe/modify_select.rst +++ b/py-polars/docs/source/reference/lazyframe/modify_select.rst @@ -26,6 +26,7 @@ Manipulation/selection LazyFrame.interpolate LazyFrame.join LazyFrame.join_asof + LazyFrame.join_where LazyFrame.last LazyFrame.limit LazyFrame.melt diff --git a/py-polars/docs/source/reference/series/aggregation.rst b/py-polars/docs/source/reference/series/aggregation.rst index 2f6f8776ea34..fe74d9eb4fd0 100644 --- a/py-polars/docs/source/reference/series/aggregation.rst +++ b/py-polars/docs/source/reference/series/aggregation.rst @@ -8,6 +8,7 @@ Aggregation Series.arg_max Series.arg_min + Series.count Series.implode Series.max Series.mean diff --git a/py-polars/docs/source/reference/series/array.rst b/py-polars/docs/source/reference/series/array.rst index 28976e1cab7d..92effb371544 100644 --- a/py-polars/docs/source/reference/series/array.rst +++ b/py-polars/docs/source/reference/series/array.rst @@ -9,27 +9,27 @@ The following methods are available under the `Series.arr` attribute. 
:toctree: api/ :template: autosummary/accessor_method.rst - Series.arr.max - Series.arr.min - Series.arr.median - Series.arr.sum - Series.arr.std - Series.arr.to_list - Series.arr.unique - Series.arr.n_unique - Series.arr.var Series.arr.all Series.arr.any - Series.arr.sort - Series.arr.reverse - Series.arr.arg_min Series.arr.arg_max - Series.arr.get - Series.arr.first - Series.arr.last - Series.arr.join - Series.arr.explode + Series.arr.arg_min Series.arr.contains Series.arr.count_matches + Series.arr.explode + Series.arr.first + Series.arr.get + Series.arr.join + Series.arr.last + Series.arr.max + Series.arr.median + Series.arr.min + Series.arr.n_unique + Series.arr.reverse + Series.arr.shift + Series.arr.sort + Series.arr.std + Series.arr.sum + Series.arr.to_list Series.arr.to_struct - Series.arr.shift \ No newline at end of file + Series.arr.unique + Series.arr.var diff --git a/py-polars/docs/source/reference/series/attributes.rst b/py-polars/docs/source/reference/series/attributes.rst index 50000e43edc7..2e2deb52f890 100644 --- a/py-polars/docs/source/reference/series/attributes.rst +++ b/py-polars/docs/source/reference/series/attributes.rst @@ -7,6 +7,6 @@ Attributes :toctree: api/ Series.dtype + Series.flags Series.name Series.shape - Series.flags diff --git a/py-polars/docs/source/reference/series/export.rst b/py-polars/docs/source/reference/series/export.rst index c1c7bacf8086..268ef25113fd 100644 --- a/py-polars/docs/source/reference/series/export.rst +++ b/py-polars/docs/source/reference/series/export.rst @@ -8,11 +8,13 @@ Export Series data to other formats: .. autosummary:: :toctree: api/ + Series.__array__ + Series.__arrow_c_stream__ Series.to_arrow Series.to_frame + Series.to_init_repr Series.to_jax Series.to_list Series.to_numpy Series.to_pandas - Series.to_init_repr Series.to_torch diff --git a/py-polars/docs/source/reference/series/index.rst b/py-polars/docs/source/reference/series/index.rst index a8476da64b97..d507312498c9 100644 --- a/py-polars/docs/source/reference/series/index.rst +++ b/py-polars/docs/source/reference/series/index.rst @@ -20,6 +20,7 @@ This page gives an overview of all public Series methods. list modify_select miscellaneous + operators plot string struct diff --git a/py-polars/docs/source/reference/series/list.rst b/py-polars/docs/source/reference/series/list.rst index a9a11001402b..ab857679c059 100644 --- a/py-polars/docs/source/reference/series/list.rst +++ b/py-polars/docs/source/reference/series/list.rst @@ -11,17 +11,18 @@ The following methods are available under the `Series.list` attribute. Series.list.all Series.list.any - Series.list.drop_nulls Series.list.arg_max Series.list.arg_min Series.list.concat Series.list.contains Series.list.count_matches Series.list.diff + Series.list.drop_nulls Series.list.eval Series.list.explode Series.list.first Series.list.gather + Series.list.gather_every Series.list.get Series.list.head Series.list.join @@ -31,6 +32,7 @@ The following methods are available under the `Series.list` attribute. Series.list.mean Series.list.median Series.list.min + Series.list.n_unique Series.list.reverse Series.list.sample Series.list.set_difference @@ -46,6 +48,4 @@ The following methods are available under the `Series.list` attribute. 
Series.list.to_array Series.list.to_struct Series.list.unique - Series.list.n_unique Series.list.var - Series.list.gather_every diff --git a/py-polars/docs/source/reference/series/miscellaneous.rst b/py-polars/docs/source/reference/series/miscellaneous.rst index 7928cf87ad4e..729071b69b95 100644 --- a/py-polars/docs/source/reference/series/miscellaneous.rst +++ b/py-polars/docs/source/reference/series/miscellaneous.rst @@ -7,8 +7,8 @@ Miscellaneous :toctree: api/ Series.equals + Series.get_chunks Series.map_elements Series.reinterpret Series.set_sorted Series.to_physical - Series.get_chunks diff --git a/py-polars/docs/source/reference/series/modify_select.rst b/py-polars/docs/source/reference/series/modify_select.rst index d7ad90029349..3b15ec11ecb3 100644 --- a/py-polars/docs/source/reference/series/modify_select.rst +++ b/py-polars/docs/source/reference/series/modify_select.rst @@ -6,6 +6,7 @@ Manipulation/selection .. autosummary:: :toctree: api/ + Series.__getitem__ Series.alias Series.append Series.arg_sort diff --git a/py-polars/docs/source/reference/series/operators.rst b/py-polars/docs/source/reference/series/operators.rst new file mode 100644 index 000000000000..e01c1b39e9de --- /dev/null +++ b/py-polars/docs/source/reference/series/operators.rst @@ -0,0 +1,31 @@ +========= +Operators +========= + +Polars supports native Python operators for all common operations; +many of these operators are also available as methods on the :class:`Series` +class. + +Comparison +~~~~~~~~~~ + +.. currentmodule:: polars +.. autosummary:: + :toctree: api/ + + Series.eq + Series.eq_missing + Series.ge + Series.gt + Series.le + Series.lt + Series.ne + Series.ne_missing + +Numeric +~~~~~~~ + +.. autosummary:: + :toctree: api/ + + Series.pow diff --git a/py-polars/docs/source/reference/sql/functions/aggregate.rst b/py-polars/docs/source/reference/sql/functions/aggregate.rst index 0c574283d4b5..0244af235175 100644 --- a/py-polars/docs/source/reference/sql/functions/aggregate.rst +++ b/py-polars/docs/source/reference/sql/functions/aggregate.rst @@ -202,7 +202,7 @@ Returns the smallest (minimum) of all the elements in the grouping. STDDEV ------ -Returns the standard deviation of all the elements in the grouping. +Returns the sample standard deviation of all the elements in the grouping. .. admonition:: Aliases diff --git a/py-polars/polars/_typing.py b/py-polars/polars/_typing.py index 428c13da0e96..9b0cc722de57 100644 --- a/py-polars/polars/_typing.py +++ b/py-polars/polars/_typing.py @@ -70,6 +70,7 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: Type[List[Any]], Type[Tuple[Any, ...]], Type[bytes], + Type[object], Type["Decimal"], Type[None], ] diff --git a/py-polars/polars/_utils/cloud.py b/py-polars/polars/_utils/cloud.py index 5b427fce4059..62d1dfd3b6ec 100644 --- a/py-polars/polars/_utils/cloud.py +++ b/py-polars/polars/_utils/cloud.py @@ -3,17 +3,13 @@ from typing import TYPE_CHECKING import polars.polars as plr -from polars._utils.various import normalize_filepath if TYPE_CHECKING: - from pathlib import Path - from polars import LazyFrame def prepare_cloud_plan( lf: LazyFrame, - uri: Path | str, **optimizations: bool, ) -> bytes: """ @@ -23,9 +19,6 @@ def prepare_cloud_plan( ---------- lf The LazyFrame to prepare. - uri - Path to which the file should be written. - Must be a URI to an accessible object store location. **optimizations Optimizations to enable or disable in the query optimizer, e.g. `projection_pushdown=False`. 
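To illustrate the operator/method equivalence documented by the new `series/operators.rst` page above, a small sketch (the null-handling comments reflect standard Polars comparison semantics):

import polars as pl

s = pl.Series("a", [1, 2, None])

s == 2           # native operator; the null entry stays null in the result
s.eq(2)          # method form of `==`
s.eq_missing(2)  # null-aware variant; the null entry compares as False rather than null
s**2             # `**` dispatches to Series.pow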
@@ -41,6 +34,5 @@ def prepare_cloud_plan( ComputeError If the given LazyFrame cannot be serialized. """ - uri = normalize_filepath(uri) pylf = lf._set_sink_optimizations(**optimizations) - return plr.prepare_cloud_plan(pylf, uri) + return plr.prepare_cloud_plan(pylf) diff --git a/py-polars/polars/_utils/construction/dataframe.py b/py-polars/polars/_utils/construction/dataframe.py index 680c29e19ba3..90b7ef485655 100644 --- a/py-polars/polars/_utils/construction/dataframe.py +++ b/py-polars/polars/_utils/construction/dataframe.py @@ -1044,13 +1044,17 @@ def to_frame_chunk(values: list[Any], schema: SchemaDefinition | None) -> DataFr return df._df -def _check_pandas_columns(data: pd.DataFrame) -> None: +def _check_pandas_columns(data: pd.DataFrame, *, include_index: bool) -> None: """Check pandas dataframe columns can be converted to polars.""" stringified_cols: set[str] = {str(col) for col in data.columns} - stringified_index: set[str] = {str(idx) for idx in data.index.names} + stringified_index: set[str] = ( + {str(idx) for idx in data.index.names} if include_index else set() + ) non_unique_cols: bool = len(stringified_cols) < len(data.columns) - non_unique_indices: bool = len(stringified_index) < len(data.index.names) + non_unique_indices: bool = ( + (len(stringified_index) < len(data.index.names)) if include_index else False + ) if non_unique_cols or non_unique_indices: msg = ( "Pandas dataframe contains non-unique indices and/or column names. " @@ -1075,7 +1079,7 @@ def pandas_to_pydf( include_index: bool = False, ) -> PyDataFrame: """Construct a PyDataFrame from a pandas DataFrame.""" - _check_pandas_columns(data) + _check_pandas_columns(data, include_index=include_index) convert_index = include_index and not _pandas_has_default_index(data) if not convert_index and all( diff --git a/py-polars/polars/_utils/construction/series.py b/py-polars/polars/_utils/construction/series.py index f13b9f5b0ec5..379bdbeb0a30 100644 --- a/py-polars/polars/_utils/construction/series.py +++ b/py-polars/polars/_utils/construction/series.py @@ -179,7 +179,7 @@ def sequence_to_pyseries( python_dtype = type(value) # temporal branch - if python_dtype in py_temporal_types: + if issubclass(python_dtype, tuple(py_temporal_types)): if dtype is None: dtype = parse_into_dtype(python_dtype) # construct from integer elif dtype in py_temporal_types: diff --git a/py-polars/polars/_utils/various.py b/py-polars/polars/_utils/various.py index 014e601de8e2..f82bbec0d785 100644 --- a/py-polars/polars/_utils/various.py +++ b/py-polars/polars/_utils/various.py @@ -84,6 +84,24 @@ def _is_iterable_of(val: Iterable[object], eltype: type | tuple[type, ...]) -> b return all(isinstance(x, eltype) for x in val) +def is_path_or_str_sequence( + val: object, *, allow_str: bool = False, include_series: bool = False +) -> TypeGuard[Sequence[str | Path]]: + """ + Check that `val` is a sequence of strings or paths. + + Note that a single string is a sequence of strings by definition, use + `allow_str=False` to return False on a single string. 
+ """ + if allow_str is False and isinstance(val, str): + return False + elif _check_for_numpy(val) and isinstance(val, np.ndarray): + return np.issubdtype(val.dtype, np.str_) + elif include_series and isinstance(val, pl.Series): + return val.dtype == pl.String + return isinstance(val, Sequence) and _is_iterable_of(val, (Path, str)) + + def is_bool_sequence( val: object, *, include_series: bool = False ) -> TypeGuard[Sequence[bool]]: diff --git a/py-polars/polars/api.py b/py-polars/polars/api.py index 44ba02084778..84262d1b7998 100644 --- a/py-polars/polars/api.py +++ b/py-polars/polars/api.py @@ -1,7 +1,5 @@ from __future__ import annotations -from functools import reduce -from operator import or_ from typing import TYPE_CHECKING, Callable, Generic, TypeVar from warnings import warn @@ -20,9 +18,8 @@ ] # do not allow override of polars' own namespaces (as registered by '_accessors') -_reserved_namespaces: set[str] = reduce( - or_, - (cls._accessors for cls in (pl.DataFrame, pl.Expr, pl.LazyFrame, pl.Series)), +_reserved_namespaces: set[str] = set.union( + *(cls._accessors for cls in (pl.DataFrame, pl.Expr, pl.LazyFrame, pl.Series)) ) diff --git a/py-polars/polars/config.py b/py-polars/polars/config.py index 6cb01bdafb9e..e2f7f27292d2 100644 --- a/py-polars/polars/config.py +++ b/py-polars/polars/config.py @@ -21,7 +21,6 @@ __all__ = ["Config"] - TableFormatNames: TypeAlias = Literal[ "ASCII_FULL", "ASCII_FULL_CONDENSED", @@ -177,7 +176,7 @@ def __exit__( self._original_state = "" @classmethod - def load(cls, cfg: str) -> type[Config]: + def load(cls, cfg: str) -> Config: """ Load (and set) previously saved Config options from a JSON string. @@ -197,14 +196,21 @@ def load(cls, cfg: str) -> type[Config]: msg = "invalid Config string (did you mean to use `load_from_file`?)" raise ValueError(msg) from err - os.environ.update(options.get("environment", {})) + cfg_load = Config() + opts = options.get("environment", {}) + for key, opt in opts.items(): + if opt is None: + os.environ.pop(key, None) + else: + os.environ[key] = opt + for cfg_methodname, value in options.get("direct", {}).items(): - if hasattr(cls, cfg_methodname): - getattr(cls, cfg_methodname)(value) - return cls + if hasattr(cfg_load, cfg_methodname): + getattr(cfg_load, cfg_methodname)(value) + return cfg_load @classmethod - def load_from_file(cls, file: Path | str) -> type[Config]: + def load_from_file(cls, file: Path | str) -> Config: """ Load (and set) previously saved Config options from file. @@ -251,10 +257,16 @@ def restore_defaults(cls) -> type[Config]: return cls @classmethod - def save(cls) -> str: + def save(cls, *, if_set: bool = False) -> str: """ Save the current set of Config options as a JSON string. + Parameters + ---------- + if_set + By default this will save the state of all configuration options; set + to `True` to save only those that have been set to a non-default value. + See Also -------- load : Load (and set) Config options from a JSON string. @@ -263,7 +275,7 @@ def save(cls) -> str: Examples -------- - >>> json_str = pl.Config.save() + >>> json_state = pl.Config.save() Returns ------- @@ -271,9 +283,9 @@ def save(cls) -> str: JSON string containing current Config options. 
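A short round-trip sketch of the revised `Config.save` / `Config.load` behaviour shown above (`set_tbl_rows` is just an arbitrary option chosen for illustration):

import polars as pl

pl.Config.set_tbl_rows(5)
state = pl.Config.save(if_set=True)  # JSON string holding only the options that were explicitly set
cfg = pl.Config.load(state)          # `load` now returns a Config instance rather than the Config class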
""" environment_vars = { - key: os.environ[key] + key: os.environ.get(key) for key in sorted(_POLARS_CFG_ENV_VARS) - if (key in os.environ) + if not if_set or (os.environ.get(key) is not None) } direct_vars = { cfg_methodname: get_value() diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 444710793f2c..8a023b4e01ae 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -66,6 +66,7 @@ from polars._utils.wrap import wrap_expr, wrap_ldf, wrap_s from polars.dataframe._html import NotebookFormatter from polars.dataframe.group_by import DynamicGroupBy, GroupBy, RollingGroupBy +from polars.dataframe.plotting import DataFramePlot from polars.datatypes import ( N_INFER_DEFAULT, Boolean, @@ -82,15 +83,15 @@ ) from polars.datatypes.group import INTEGER_DTYPES from polars.dependencies import ( + _ALTAIR_AVAILABLE, _GREAT_TABLES_AVAILABLE, - _HVPLOT_AVAILABLE, _PANDAS_AVAILABLE, _PYARROW_AVAILABLE, _check_for_numpy, _check_for_pandas, _check_for_pyarrow, + altair, great_tables, - hvplot, import_optional, ) from polars.dependencies import numpy as np @@ -123,8 +124,8 @@ import numpy.typing as npt import torch from great_tables import GT - from hvplot.plotting.core import hvPlotTabularPolars - from xlsxwriter import Workbook, Worksheet + from xlsxwriter import Workbook + from xlsxwriter.worksheet import Worksheet from polars import DataType, Expr, LazyFrame, Series from polars._typing import ( @@ -603,7 +604,7 @@ def _replace(self, column: str, new_column: Series) -> DataFrame: @property @unstable() - def plot(self) -> hvPlotTabularPolars: + def plot(self) -> DataFramePlot: """ Create a plot namespace. @@ -611,9 +612,28 @@ def plot(self) -> hvPlotTabularPolars: This functionality is currently considered **unstable**. It may be changed at any point without it being considered a breaking change. + .. versionchanged:: 1.6.0 + In prior versions of Polars, HvPlot was the plotting backend. If you would + like to restore the previous plotting functionality, all you need to do + is add `import hvplot.polars` at the top of your script and replace + `df.plot` with `df.hvplot`. + Polars does not implement plotting logic itself, but instead defers to - hvplot. Please see the `hvplot reference gallery `_ - for more information and documentation. + `Altair `_: + + - `df.plot.line(**kwargs)` + is shorthand for + `alt.Chart(df).mark_line().encode(**kwargs).interactive()` + - `df.plot.point(**kwargs)` + is shorthand for + `alt.Chart(df).mark_point().encode(**kwargs).interactive()` (and + `plot.scatter` is provided as an alias) + - `df.plot.bar(**kwargs)` + is shorthand for + `alt.Chart(df).mark_bar().encode(**kwargs).interactive()` + - for any other attribute `attr`, `df.plot.attr(**kwargs)` + is shorthand for + `alt.Chart(df).mark_attr().encode(**kwargs).interactive()` Examples -------- @@ -626,32 +646,37 @@ def plot(self) -> hvPlotTabularPolars: ... "species": ["setosa", "setosa", "versicolor"], ... } ... ) - >>> df.plot.scatter(x="length", y="width", by="species") # doctest: +SKIP + >>> df.plot.point(x="length", y="width", color="species") # doctest: +SKIP Line plot: >>> from datetime import date >>> df = pl.DataFrame( ... { - ... "date": [date(2020, 1, 2), date(2020, 1, 3), date(2020, 1, 4)], - ... "stock_1": [1, 4, 6], - ... "stock_2": [1, 5, 2], + ... "date": [date(2020, 1, 2), date(2020, 1, 3), date(2020, 1, 4)] * 2, + ... "price": [1, 4, 6, 1, 5, 2], + ... "stock": ["a", "a", "a", "b", "b", "b"], ... } ... 
) - >>> df.plot.line(x="date", y=["stock_1", "stock_2"]) # doctest: +SKIP + >>> df.plot.line(x="date", y="price", color="stock") # doctest: +SKIP - For more info on what you can pass, you can use ``hvplot.help``: + Bar plot: - >>> import hvplot # doctest: +SKIP - >>> hvplot.help("scatter") # doctest: +SKIP + >>> df = pl.DataFrame( + ... { + ... "day": ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] * 2, + ... "group": ["a"] * 7 + ["b"] * 7, + ... "value": [1, 3, 2, 4, 5, 6, 1, 1, 3, 2, 4, 5, 1, 2], + ... } + ... ) + >>> df.plot.bar( + ... x="day", y="value", color="day", column="group" + ... ) # doctest: +SKIP """ - if not _HVPLOT_AVAILABLE or parse_version(hvplot.__version__) < parse_version( - "0.9.1" - ): - msg = "hvplot>=0.9.1 is required for `.plot`" + if not _ALTAIR_AVAILABLE or parse_version(altair.__version__) < (5, 4, 0): + msg = "altair>=5.4.0 is required for `.plot`" raise ModuleUpgradeRequiredError(msg) - hvplot.post_patch() - return hvplot.plotting.core.hvPlotTabularPolars(self) + return DataFramePlot(self) @property @unstable() @@ -1194,7 +1219,130 @@ def __getitem__( | tuple[MultiIndexSelector, MultiColSelector] ), ) -> DataFrame | Series | Any: - """Get part of the DataFrame as a new DataFrame, Series, or scalar.""" + """ + Get part of the DataFrame as a new DataFrame, Series, or scalar. + + Parameters + ---------- + key + Rows / columns to select. This is easiest to explain via example. Suppose + we have a DataFrame with columns `'a'`, `'d'`, `'c'`, `'d'`. Here is what + various types of `key` would do: + + - `df[0, 'a']` extracts the first element of column `'a'` and returns a + scalar. + - `df[0]` extracts the first row and returns a Dataframe. + - `df['a']` extracts column `'a'` and returns a Series. + - `df[0:2]` extracts the first two rows and returns a Dataframe. + - `df[0:2, 'a']` extracts the first two rows from column `'a'` and returns + a Series. + - `df[0:2, 0]` extracts the first two rows from the first column and returns + a Series. + - `df[[0, 1], [0, 1, 2]]` extracts the first two rows and the first three + columns and returns a Dataframe. + - `df[0: 2, ['a', 'c']]` extracts the first two rows from columns `'a'` and + `'c'` and returns a Dataframe. + - `df[:, 0: 2]` extracts all rows from the first two columns and returns a + Dataframe. + - `df[:, 'a': 'c']` extracts all rows and all columns positioned between + `'a'` and `'c'` *inclusive* and returns a Dataframe. In our example, + that would extract columns `'a'`, `'d'`, and `'c'`. + + Returns + ------- + DataFrame, Series, or scalar, depending on `key`. + + Examples + -------- + >>> df = pl.DataFrame( + ... {"a": [1, 2, 3], "d": [4, 5, 6], "c": [1, 3, 2], "b": [7, 8, 9]} + ... 
) + >>> df[0] + shape: (1, 4) + ┌─────┬─────┬─────┬─────┐ + │ a ┆ d ┆ c ┆ b │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ i64 ┆ i64 │ + ╞═════╪═════╪═════╪═════╡ + │ 1 ┆ 4 ┆ 1 ┆ 7 │ + └─────┴─────┴─────┴─────┘ + >>> df[0, "a"] + 1 + >>> df["a"] + shape: (3,) + Series: 'a' [i64] + [ + 1 + 2 + 3 + ] + >>> df[0:2] + shape: (2, 4) + ┌─────┬─────┬─────┬─────┐ + │ a ┆ d ┆ c ┆ b │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ i64 ┆ i64 │ + ╞═════╪═════╪═════╪═════╡ + │ 1 ┆ 4 ┆ 1 ┆ 7 │ + │ 2 ┆ 5 ┆ 3 ┆ 8 │ + └─────┴─────┴─────┴─────┘ + >>> df[0:2, "a"] + shape: (2,) + Series: 'a' [i64] + [ + 1 + 2 + ] + >>> df[0:2, 0] + shape: (2,) + Series: 'a' [i64] + [ + 1 + 2 + ] + >>> df[[0, 1], [0, 1, 2]] + shape: (2, 3) + ┌─────┬─────┬─────┐ + │ a ┆ d ┆ c │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ i64 │ + ╞═════╪═════╪═════╡ + │ 1 ┆ 4 ┆ 1 │ + │ 2 ┆ 5 ┆ 3 │ + └─────┴─────┴─────┘ + >>> df[0:2, ["a", "c"]] + shape: (2, 2) + ┌─────┬─────┐ + │ a ┆ c │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪═════╡ + │ 1 ┆ 1 │ + │ 2 ┆ 3 │ + └─────┴─────┘ + >>> df[:, 0:2] + shape: (3, 2) + ┌─────┬─────┐ + │ a ┆ d │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪═════╡ + │ 1 ┆ 4 │ + │ 2 ┆ 5 │ + │ 3 ┆ 6 │ + └─────┴─────┘ + >>> df[:, "a":"c"] + shape: (3, 3) + ┌─────┬─────┬─────┐ + │ a ┆ d ┆ c │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ i64 │ + ╞═════╪═════╪═════╡ + │ 1 ┆ 4 ┆ 1 │ + │ 2 ┆ 5 ┆ 3 │ + │ 3 ┆ 6 ┆ 2 │ + └─────┴─────┴─────┘ + """ return get_df_item_by_key(self, key) def __setitem__( @@ -2709,10 +2857,20 @@ def write_csv( if not null_value: null_value = None + def write_csv_to_string() -> str: + with BytesIO() as buf: + self.write_csv(buf) + csv_bytes = buf.getvalue() + return csv_bytes.decode("utf8") + should_return_buffer = False if file is None: buffer = file = BytesIO() should_return_buffer = True + elif isinstance(file, StringIO): + csv_str = write_csv_to_string() + file.write(csv_str) + return None elif isinstance(file, (str, os.PathLike)): file = normalize_filepath(file) @@ -3820,6 +3978,7 @@ def unpack_table_name(name: str) -> tuple[str | None, str | None, str]: mode=mode, catalog_name=catalog, db_schema_name=db_schema, + **(engine_options or {}), ) elif db_schema is not None: adbc_str_version = ".".join(str(v) for v in adbc_version) @@ -5208,7 +5367,7 @@ def equals(self, other: DataFrame, *, null_equal: bool = True) -> bool: See Also -------- - assert_frame_equal + polars.testing.assert_frame_equal Examples -------- @@ -5810,7 +5969,7 @@ def group_by( >>> for name, data in df.group_by("a"): # doctest: +SKIP ... print(name) ... print(data) - a + ('a',) shape: (2, 3) ┌─────┬─────┬─────┐ │ a ┆ b ┆ c │ @@ -5820,7 +5979,7 @@ def group_by( │ a ┆ 1 ┆ 5 │ │ a ┆ 1 ┆ 3 │ └─────┴─────┴─────┘ - b + ('b',) shape: (2, 3) ┌─────┬─────┬─────┐ │ a ┆ b ┆ c │ @@ -5830,7 +5989,7 @@ def group_by( │ b ┆ 2 ┆ 4 │ │ b ┆ 3 ┆ 2 │ └─────┴─────┴─────┘ - c + ('c',) shape: (1, 3) ┌─────┬─────┬─────┐ │ a ┆ b ┆ c │ @@ -6179,7 +6338,7 @@ def group_by_dynamic( │ 2021-12-16 03:00:00 ┆ 6 │ └─────────────────────┴─────┘ - Group by windows of 1 hour starting at 2021-12-16 00:00:00. + Group by windows of 1 hour. >>> df.group_by_dynamic("time", every="1h", closed="right").agg(pl.col("n")) shape: (4, 2) @@ -6438,7 +6597,7 @@ def join_asof( tolerance: str | int | float | timedelta | None = None, allow_parallel: bool = True, force_parallel: bool = False, - coalesce: bool | None = None, + coalesce: bool = True, ) -> DataFrame: """ Perform an asof join. 
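Stepping back to the `write_csv` change a few hunks above, which adds a dedicated `StringIO` branch: a minimal sketch of how that path behaves (the comment shows the expected default CSV layout):

from io import StringIO

import polars as pl

df = pl.DataFrame({"a": [1, 2], "b": ["x", "y"]})
buf = StringIO()
df.write_csv(buf)      # the new branch writes decoded UTF-8 CSV text into the StringIO
print(buf.getvalue())  # "a,b", "1,x", "2,y" on separate lines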
@@ -6516,9 +6675,8 @@ def join_asof( Force the physical plan to evaluate the computation of both DataFrames up to the join in parallel. coalesce - Coalescing behavior (merging of join columns). + Coalescing behavior (merging of `on` / `left_on` / `right_on` columns): - - None: -> join specific. - True: -> Always coalesce join columns. - False: -> Never coalesce join columns. @@ -6592,6 +6750,20 @@ def join_asof( - date `2016-03-01` from `population` is matched with `2016-01-01` from `gdp`; - date `2018-08-01` from `population` is matched with `2018-01-01` from `gdp`. + You can verify this by passing `coalesce=False`: + + >>> population.join_asof(gdp, on="date", strategy="backward", coalesce=False) + shape: (3, 4) + ┌────────────┬────────────┬────────────┬──────┐ + │ date ┆ population ┆ date_right ┆ gdp │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ date ┆ f64 ┆ date ┆ i64 │ + ╞════════════╪════════════╪════════════╪══════╡ + │ 2016-03-01 ┆ 82.19 ┆ 2016-01-01 ┆ 4164 │ + │ 2018-08-01 ┆ 82.66 ┆ 2018-01-01 ┆ 4566 │ + │ 2019-01-01 ┆ 83.12 ┆ 2019-01-01 ┆ 4696 │ + └────────────┴────────────┴────────────┴──────┘ + If we instead use `strategy='forward'`, then each date from `population` which doesn't have an exact match is matched with the closest later date from `gdp`: @@ -6823,10 +6995,6 @@ def join( Note that joining on any other expressions than `col` will turn off coalescing. - Returns - ------- - DataFrame - See Also -------- join_asof @@ -6927,6 +7095,94 @@ def join( .collect(_eager=True) ) + @unstable() + def join_where( + self, + other: DataFrame, + *predicates: Expr | Iterable[Expr], + suffix: str = "_right", + ) -> DataFrame: + """ + Perform a join based on one or multiple equality predicates. + + .. warning:: + This functionality is experimental. It may be + changed at any point without it being considered a breaking change. + + A row from this table may be included in zero or multiple rows in the result, + and the relative order of rows may differ between the input and output tables. + + Parameters + ---------- + other + DataFrame to join with. + *predicates + (In)Equality condition to join the two table on. + The left `pl.col(..)` will refer to the left table + and the right `pl.col(..)` + to the right table. + For example: `pl.col("time") >= pl.col("duration")` + suffix + Suffix to append to columns with a duplicate name. + + Notes + ----- + This method is strict about its equality expressions. + Only 1 equality expression is allowed per predicate, where + the lhs `pl.col` refers to the left table in the join, and the + rhs `pl.col` refers to the right table. + + Examples + -------- + >>> east = pl.DataFrame( + ... { + ... "id": [100, 101, 102], + ... "dur": [120, 140, 160], + ... "rev": [12, 14, 16], + ... "cores": [2, 8, 4], + ... } + ... ) + >>> west = pl.DataFrame( + ... { + ... "t_id": [404, 498, 676, 742], + ... "time": [90, 130, 150, 170], + ... "cost": [9, 13, 15, 16], + ... "cores": [4, 2, 1, 4], + ... } + ... ) + >>> east.join_where( + ... west, + ... pl.col("dur") < pl.col("time"), + ... pl.col("rev") < pl.col("cost"), + ... 
) + shape: (5, 8) + ┌─────┬─────┬─────┬───────┬──────┬──────┬──────┬─────────────┐ + │ id ┆ dur ┆ rev ┆ cores ┆ t_id ┆ time ┆ cost ┆ cores_right │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ + ╞═════╪═════╪═════╪═══════╪══════╪══════╪══════╪═════════════╡ + │ 100 ┆ 120 ┆ 12 ┆ 2 ┆ 498 ┆ 130 ┆ 13 ┆ 2 │ + │ 100 ┆ 120 ┆ 12 ┆ 2 ┆ 676 ┆ 150 ┆ 15 ┆ 1 │ + │ 100 ┆ 120 ┆ 12 ┆ 2 ┆ 742 ┆ 170 ┆ 16 ┆ 4 │ + │ 101 ┆ 140 ┆ 14 ┆ 8 ┆ 676 ┆ 150 ┆ 15 ┆ 1 │ + │ 101 ┆ 140 ┆ 14 ┆ 8 ┆ 742 ┆ 170 ┆ 16 ┆ 4 │ + └─────┴─────┴─────┴───────┴──────┴──────┴──────┴─────────────┘ + + """ + if not isinstance(other, DataFrame): + msg = f"expected `other` join table to be a DataFrame, got {type(other).__name__!r}" + raise TypeError(msg) + + return ( + self.lazy() + .join_where( + other.lazy(), + *predicates, + suffix=suffix, + ) + .collect(_eager=True) + ) + def map_rows( self, function: Callable[[tuple[Any, ...]], Any], @@ -7003,17 +7259,17 @@ def map_rows( Return a DataFrame with a single column by mapping each row to a scalar: - >>> df.map_rows(lambda t: (t[0] * 2 + t[1])) # doctest: +SKIP + >>> df.map_rows(lambda t: (t[0] * 2 + t[1])) shape: (3, 1) - ┌───────┐ - │ apply │ - │ --- │ - │ i64 │ - ╞═══════╡ - │ 1 │ - │ 9 │ - │ 14 │ - └───────┘ + ┌─────┐ + │ map │ + │ --- │ + │ i64 │ + ╞═════╡ + │ 1 │ + │ 9 │ + │ 14 │ + └─────┘ In this case it is better to use the following native expression: @@ -8581,17 +8837,15 @@ def lazy(self) -> LazyFrame: """ Start a lazy query from this point. This returns a `LazyFrame` object. - Operations on a `LazyFrame` are not executed until this is requested by either - calling: + Operations on a `LazyFrame` are not executed until this is triggered + by calling one of: * :meth:`.collect() ` (run on all data) - * :meth:`.describe_plan() ` - (print unoptimized query plan) - * :meth:`.describe_optimized_plan() ` - (print optimized query plan) + * :meth:`.explain() ` + (print the query plan) * :meth:`.show_graph() ` - (show (un)optimized query plan as graphviz graph) + (show the query plan as graphviz graph) * :meth:`.collect_schema() ` (return the final frame schema) diff --git a/py-polars/polars/dataframe/group_by.py b/py-polars/polars/dataframe/group_by.py index 42b386fdda7d..f3a9c185de41 100644 --- a/py-polars/polars/dataframe/group_by.py +++ b/py-polars/polars/dataframe/group_by.py @@ -93,13 +93,15 @@ def __iter__(self) -> Self: │ b ┆ 3 │ └─────┴─────┘ """ + # Every group gather can trigger a rechunk, so do early. 
+ self.df = self.df.rechunk() temp_col = "__POLARS_GB_GROUP_INDICES" groups_df = ( self.df.lazy() .group_by(*self.by, **self.named_by, maintain_order=self.maintain_order) .agg(F.first().agg_groups().alias(temp_col)) .collect(no_optimization=True) - ).rechunk() + ) self._group_names = groups_df.select(F.all().exclude(temp_col)).iter_rows() self._group_indices = groups_df.select(temp_col).to_series() diff --git a/py-polars/polars/dataframe/plotting.py b/py-polars/polars/dataframe/plotting.py new file mode 100644 index 000000000000..ed118e504656 --- /dev/null +++ b/py-polars/polars/dataframe/plotting.py @@ -0,0 +1,256 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Dict, Union + +if TYPE_CHECKING: + import sys + + import altair as alt + from altair.typing import ( + ChannelColor, + ChannelOrder, + ChannelSize, + ChannelTooltip, + ChannelX, + ChannelY, + EncodeKwds, + ) + + from polars import DataFrame + + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + if sys.version_info >= (3, 11): + from typing import Unpack + else: + from typing_extensions import Unpack + + Encodings: TypeAlias = Dict[ + str, + Union[ + ChannelX, ChannelY, ChannelColor, ChannelOrder, ChannelSize, ChannelTooltip + ], + ] + + +class DataFramePlot: + """DataFrame.plot namespace.""" + + def __init__(self, df: DataFrame) -> None: + import altair as alt + + self._chart = alt.Chart(df) + + def bar( + self, + x: ChannelX | None = None, + y: ChannelY | None = None, + color: ChannelColor | None = None, + tooltip: ChannelTooltip | None = None, + /, + **kwargs: Unpack[EncodeKwds], + ) -> alt.Chart: + """ + Draw bar plot. + + Polars does not implement plotting logic itself but instead defers to + `Altair `_. + + `df.plot.bar(**kwargs)` is shorthand for + `alt.Chart(df).mark_bar().encode(**kwargs).interactive()`, + and is provided for convenience - for full customisatibility, use a plotting + library directly. + + .. versionchanged:: 1.6.0 + In prior versions of Polars, HvPlot was the plotting backend. If you would + like to restore the previous plotting functionality, all you need to do + is add `import hvplot.polars` at the top of your script and replace + `df.plot` with `df.hvplot`. + + Parameters + ---------- + x + Column with x-coordinates of bars. + y + Column with y-coordinates of bars. + color + Column to color bars by. + tooltip + Columns to show values of when hovering over bars with pointer. + **kwargs + Additional keyword arguments passed to Altair. + + Examples + -------- + >>> df = pl.DataFrame( + ... { + ... "day": ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] * 2, + ... "group": ["a"] * 7 + ["b"] * 7, + ... "value": [1, 3, 2, 4, 5, 6, 1, 1, 3, 2, 4, 5, 1, 2], + ... } + ... ) + >>> df.plot.bar( + ... x="day", y="value", color="day", column="group" + ... ) # doctest: +SKIP + """ + encodings: Encodings = {} + if x is not None: + encodings["x"] = x + if y is not None: + encodings["y"] = y + if color is not None: + encodings["color"] = color + if tooltip is not None: + encodings["tooltip"] = tooltip + return self._chart.mark_bar().encode(**encodings, **kwargs).interactive() + + def line( + self, + x: ChannelX | None = None, + y: ChannelY | None = None, + color: ChannelColor | None = None, + order: ChannelOrder | None = None, + tooltip: ChannelTooltip | None = None, + /, + **kwargs: Unpack[EncodeKwds], + ) -> alt.Chart: + """ + Draw line plot. 
+ + Polars does not implement plotting logic itself but instead defers to + `Altair `_. + + `alt.Chart(df).mark_line().encode(**kwargs).interactive()`, + and is provided for convenience - for full customisatibility, use a plotting + library directly. + + .. versionchanged:: 1.6.0 + In prior versions of Polars, HvPlot was the plotting backend. If you would + like to restore the previous plotting functionality, all you need to do + is add `import hvplot.polars` at the top of your script and replace + `df.plot` with `df.hvplot`. + + Parameters + ---------- + x + Column with x-coordinates of lines. + y + Column with y-coordinates of lines. + color + Column to color lines by. + order + Column to use for order of data points in lines. + tooltip + Columns to show values of when hovering over lines with pointer. + **kwargs + Additional keyword arguments passed to Altair. + + Examples + -------- + >>> from datetime import date + >>> df = pl.DataFrame( + ... { + ... "date": [date(2020, 1, 2), date(2020, 1, 3), date(2020, 1, 4)] * 2, + ... "price": [1, 4, 6, 1, 5, 2], + ... "stock": ["a", "a", "a", "b", "b", "b"], + ... } + ... ) + >>> df.plot.line(x="date", y="price", color="stock") # doctest: +SKIP + """ + encodings: Encodings = {} + if x is not None: + encodings["x"] = x + if y is not None: + encodings["y"] = y + if color is not None: + encodings["color"] = color + if order is not None: + encodings["order"] = order + if tooltip is not None: + encodings["tooltip"] = tooltip + return self._chart.mark_line().encode(**encodings, **kwargs).interactive() + + def point( + self, + x: ChannelX | None = None, + y: ChannelY | None = None, + color: ChannelColor | None = None, + size: ChannelSize | None = None, + tooltip: ChannelTooltip | None = None, + /, + **kwargs: Unpack[EncodeKwds], + ) -> alt.Chart: + """ + Draw scatter plot. + + Polars does not implement plotting logic itself but instead defers to + `Altair `_. + + `df.plot.point(**kwargs)` is shorthand for + `alt.Chart(df).mark_point().encode(**kwargs).interactive()`, + and is provided for convenience - for full customisatibility, use a plotting + library directly. + + .. versionchanged:: 1.6.0 + In prior versions of Polars, HvPlot was the plotting backend. If you would + like to restore the previous plotting functionality, all you need to do + is add `import hvplot.polars` at the top of your script and replace + `df.plot` with `df.hvplot`. + + Parameters + ---------- + x + Column with x-coordinates of points. + y + Column with y-coordinates of points. + color + Column to color points by. + size + Column which determines points' sizes. + tooltip + Columns to show values of when hovering over points with pointer. + **kwargs + Additional keyword arguments passed to Altair. + + Examples + -------- + >>> df = pl.DataFrame( + ... { + ... "length": [1, 4, 6], + ... "width": [4, 5, 6], + ... "species": ["setosa", "setosa", "versicolor"], + ... } + ... ) + >>> df.plot.point(x="length", y="width", color="species") # doctest: +SKIP + """ + encodings: Encodings = {} + if x is not None: + encodings["x"] = x + if y is not None: + encodings["y"] = y + if color is not None: + encodings["color"] = color + if size is not None: + encodings["size"] = size + if tooltip is not None: + encodings["tooltip"] = tooltip + return ( + self._chart.mark_point() + .encode( + **encodings, + **kwargs, + ) + .interactive() + ) + + # Alias to `point` because of how common it is. 
+ scatter = point + + def __getattr__(self, attr: str) -> Callable[..., alt.Chart]: + method = getattr(self._chart, f"mark_{attr}", None) + if method is None: + msg = f"Altair has no method 'mark_{attr}'" + raise AttributeError(msg) + return lambda **kwargs: method().encode(**kwargs).interactive() diff --git a/py-polars/polars/datatypes/_parse.py b/py-polars/polars/datatypes/_parse.py index 2649bc7905ec..e7ac78cae6dd 100644 --- a/py-polars/polars/datatypes/_parse.py +++ b/py-polars/polars/datatypes/_parse.py @@ -76,10 +76,10 @@ def parse_py_type_into_dtype(input: PythonDataType | type[object]) -> PolarsData return String() elif input is bool: return Boolean() - elif input is date: - return Date() - elif input is datetime: + elif isinstance(input, type) and issubclass(input, datetime): # type: ignore[redundant-expr] return Datetime("us") + elif isinstance(input, type) and issubclass(input, date): # type: ignore[redundant-expr] + return Date() elif input is timedelta: return Duration elif input is time: @@ -97,16 +97,14 @@ def parse_py_type_into_dtype(input: PythonDataType | type[object]) -> PolarsData # this is required as pass through. Don't remove elif input == Unknown: return Unknown - elif hasattr(input, "__origin__") and hasattr(input, "__args__"): return _parse_generic_into_dtype(input) - else: _raise_on_invalid_dtype(input) def _parse_generic_into_dtype(input: Any) -> PolarsDataType: - """Parse a generic type into a Polars data type.""" + """Parse a generic type (from typing annotation) into a Polars data type.""" base_type = input.__origin__ if base_type not in (tuple, list): _raise_on_invalid_dtype(input) @@ -124,19 +122,19 @@ def _parse_generic_into_dtype(input: Any) -> PolarsDataType: PY_TYPE_STR_TO_DTYPE: SchemaDict = { - "int": Int64(), - "float": Float64(), + "Decimal": Decimal, + "NoneType": Null(), "bool": Boolean(), - "str": String(), "bytes": Binary(), "date": Date(), - "time": Time(), "datetime": Datetime("us"), + "float": Float64(), + "int": Int64(), + "list": List, "object": Object(), - "NoneType": Null(), + "str": String(), + "time": Time(), "timedelta": Duration, - "Decimal": Decimal, - "list": List, "tuple": List, } @@ -177,5 +175,7 @@ def _parse_union_type_into_dtype(input: Any) -> PolarsDataType: def _raise_on_invalid_dtype(input: Any) -> NoReturn: """Raise an informative error if the input could not be parsed.""" - msg = f"cannot parse input of type {type(input).__name__!r} into Polars data type: {input!r}" + input_type = input if type(input) is type else f"of type {type(input).__name__!r}" + input_detail = "" if type(input) is type else f" (given: {input!r})" + msg = f"cannot parse input {input_type} into Polars data type{input_detail}" raise TypeError(msg) from None diff --git a/py-polars/polars/datatypes/classes.py b/py-polars/polars/datatypes/classes.py index 08aeb53c5674..b815d7d17608 100644 --- a/py-polars/polars/datatypes/classes.py +++ b/py-polars/polars/datatypes/classes.py @@ -83,6 +83,14 @@ def is_temporal(cls) -> bool: # noqa: D102 def is_nested(cls) -> bool: # noqa: D102 ... + @classmethod + def from_python(cls, py_type: PythonDataType) -> PolarsDataType: # noqa: D102 + ... + + @classmethod + def to_python(self) -> PythonDataType: # noqa: D102 + ... 
+ class DataType(metaclass=DataTypeClass): """Base class for all Polars data types.""" @@ -180,6 +188,49 @@ def is_nested(cls) -> bool: """Check whether the data type is a nested type.""" return issubclass(cls, NestedType) + @classmethod + def from_python(cls, py_type: PythonDataType) -> PolarsDataType: + """ + Return the Polars data type corresponding to a given Python type. + + Notes + ----- + Not every Python type has a corresponding Polars data type; in general + you should declare Polars data types explicitly to exactly specify + the desired type and its properties (such as scale/unit). + + Examples + -------- + >>> pl.DataType.from_python(int) + Int64 + >>> pl.DataType.from_python(float) + Float64 + >>> from datetime import tzinfo + >>> pl.DataType.from_python(tzinfo) # doctest: +SKIP + TypeError: cannot parse input into Polars data type + """ + from polars.datatypes._parse import parse_into_dtype + + return parse_into_dtype(py_type) + + @classinstmethod # type: ignore[arg-type] + def to_python(self) -> PythonDataType: + """ + Return the Python type corresponding to this Polars data type. + + Examples + -------- + >>> pl.Int16().to_python() + + >>> pl.Float32().to_python() + + >>> pl.Array(pl.Date(), 10).to_python() + + """ + from polars.datatypes import dtype_to_py_type + + return dtype_to_py_type(self) + class NumericType(DataType): """Base class for numeric data types.""" diff --git a/py-polars/polars/datatypes/convert.py b/py-polars/polars/datatypes/convert.py index a965422c7530..1b0806b2ea75 100644 --- a/py-polars/polars/datatypes/convert.py +++ b/py-polars/polars/datatypes/convert.py @@ -19,6 +19,7 @@ Datetime, Decimal, Duration, + Enum, Field, Float32, Float64, @@ -134,55 +135,60 @@ class _DataTypeMappings: @functools.lru_cache # noqa: B019 def DTYPE_TO_FFINAME(self) -> dict[PolarsDataType, str]: return { - Int8: "i8", - Int16: "i16", - Int32: "i32", - Int64: "i64", - UInt8: "u8", - UInt16: "u16", - UInt32: "u32", - UInt64: "u64", - Float32: "f32", - Float64: "f64", - Decimal: "decimal", + Binary: "binary", Boolean: "bool", - String: "str", - List: "list", + Categorical: "categorical", Date: "date", Datetime: "datetime", + Decimal: "decimal", Duration: "duration", - Time: "time", + Float32: "f32", + Float64: "f64", + Int16: "i16", + Int32: "i32", + Int64: "i64", + Int8: "i8", + List: "list", Object: "object", - Categorical: "categorical", + String: "str", Struct: "struct", - Binary: "binary", + Time: "time", + UInt16: "u16", + UInt32: "u32", + UInt64: "u64", + UInt8: "u8", } @property @functools.lru_cache # noqa: B019 def DTYPE_TO_PY_TYPE(self) -> dict[PolarsDataType, PythonDataType]: return { - Float64: float, + Array: list, + Binary: bytes, + Boolean: bool, + Date: date, + Datetime: datetime, + Decimal: PyDecimal, + Duration: timedelta, Float32: float, - Int64: int, - Int32: int, + Float64: float, Int16: int, + Int32: int, + Int64: int, Int8: int, + List: list, + Null: None.__class__, + Object: object, String: str, - UInt8: int, + Struct: dict, + Time: time, UInt16: int, UInt32: int, UInt64: int, - Decimal: PyDecimal, - Boolean: bool, - Duration: timedelta, - Datetime: datetime, - Date: date, - Time: time, - Binary: bytes, - List: list, - Array: list, - Null: None.__class__, + UInt8: int, + # the below mappings are appropriate as we restrict cat/enum to strings + Enum: str, + Categorical: str, } @property @@ -190,32 +196,32 @@ def DTYPE_TO_PY_TYPE(self) -> dict[PolarsDataType, PythonDataType]: def NUMPY_KIND_AND_ITEMSIZE_TO_DTYPE(self) -> dict[tuple[str, int], 
PolarsDataType]: return { # (np.dtype().kind, np.dtype().itemsize) + ("M", 8): Datetime, ("b", 1): Boolean, + ("f", 4): Float32, + ("f", 8): Float64, ("i", 1): Int8, ("i", 2): Int16, ("i", 4): Int32, ("i", 8): Int64, + ("m", 8): Duration, ("u", 1): UInt8, ("u", 2): UInt16, ("u", 4): UInt32, ("u", 8): UInt64, - ("f", 4): Float32, - ("f", 8): Float64, - ("m", 8): Duration, - ("M", 8): Datetime, } @property @functools.lru_cache # noqa: B019 def PY_TYPE_TO_ARROW_TYPE(self) -> dict[PythonDataType, pa.lib.DataType]: return { + bool: pa.bool_(), + date: pa.date32(), + datetime: pa.timestamp("us"), float: pa.float64(), int: pa.int64(), str: pa.large_utf8(), - bool: pa.bool_(), - date: pa.date32(), time: pa.time64("us"), - datetime: pa.timestamp("us"), timedelta: pa.duration("us"), None.__class__: pa.null(), } @@ -338,7 +344,7 @@ def maybe_cast(el: Any, dtype: PolarsDataType) -> Any: py_type = dtype_to_py_type(dtype) if not isinstance(el, py_type): try: - el = py_type(el) # type: ignore[call-arg, misc] + el = py_type(el) # type: ignore[call-arg] except Exception: msg = f"cannot convert Python type {type(el).__name__!r} to {dtype!r}" raise TypeError(msg) from None diff --git a/py-polars/polars/dependencies.py b/py-polars/polars/dependencies.py index ce457255bb59..10548da8c904 100644 --- a/py-polars/polars/dependencies.py +++ b/py-polars/polars/dependencies.py @@ -8,11 +8,11 @@ from types import ModuleType from typing import TYPE_CHECKING, Any, ClassVar, Hashable, cast +_ALTAIR_AVAILABLE = True _DELTALAKE_AVAILABLE = True _FSSPEC_AVAILABLE = True _GEVENT_AVAILABLE = True _GREAT_TABLES_AVAILABLE = True -_HVPLOT_AVAILABLE = True _HYPOTHESIS_AVAILABLE = True _NUMPY_AVAILABLE = True _PANDAS_AVAILABLE = True @@ -150,11 +150,11 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]: import pickle import subprocess + import altair import deltalake import fsspec import gevent import great_tables - import hvplot import hypothesis import numpy import pandas @@ -175,10 +175,10 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]: subprocess, _ = _lazy_import("subprocess") # heavy/optional third party libs + altair, _ALTAIR_AVAILABLE = _lazy_import("altair") deltalake, _DELTALAKE_AVAILABLE = _lazy_import("deltalake") fsspec, _FSSPEC_AVAILABLE = _lazy_import("fsspec") great_tables, _GREAT_TABLES_AVAILABLE = _lazy_import("great_tables") - hvplot, _HVPLOT_AVAILABLE = _lazy_import("hvplot") hypothesis, _HYPOTHESIS_AVAILABLE = _lazy_import("hypothesis") numpy, _NUMPY_AVAILABLE = _lazy_import("numpy") pandas, _PANDAS_AVAILABLE = _lazy_import("pandas") @@ -301,11 +301,11 @@ def import_optional( "pickle", "subprocess", # lazy-load third party libs + "altair", "deltalake", "fsspec", "gevent", "great_tables", - "hvplot", "numpy", "pandas", "pydantic", @@ -318,11 +318,11 @@ def import_optional( "_check_for_pyarrow", "_check_for_pydantic", # exported flags/guards + "_ALTAIR_AVAILABLE", "_DELTALAKE_AVAILABLE", "_PYICEBERG_AVAILABLE", "_FSSPEC_AVAILABLE", "_GEVENT_AVAILABLE", - "_HVPLOT_AVAILABLE", "_HYPOTHESIS_AVAILABLE", "_NUMPY_AVAILABLE", "_PANDAS_AVAILABLE", diff --git a/py-polars/polars/expr/binary.py b/py-polars/polars/expr/binary.py index cac394aa457a..7ea6dc4d79ea 100644 --- a/py-polars/polars/expr/binary.py +++ b/py-polars/polars/expr/binary.py @@ -257,15 +257,20 @@ def size(self, unit: SizeUnit = "b") -> Expr: r""" Get the size of binary values in the given unit. + Parameters + ---------- + unit : {'b', 'kb', 'mb', 'gb', 'tb'} + Scale the returned size to the given unit. 
+ Returns ------- Expr - Expression of data type :class:`UInt32`. + Expression of data type :class:`UInt32` or `Float64`. Examples -------- >>> from os import urandom - >>> df = pl.DataFrame({"data": [urandom(n) for n in (512, 256, 2560, 1024)]}) + >>> df = pl.DataFrame({"data": [urandom(n) for n in (512, 256, 1024)]}) >>> df.with_columns( # doctest: +IGNORE_RESULT ... n_bytes=pl.col("data").bin.size(), ... n_kilobytes=pl.col("data").bin.size("kb"), @@ -278,7 +283,6 @@ def size(self, unit: SizeUnit = "b") -> Expr: ╞═════════════════════════════════╪═════════╪═════════════╡ │ b"y?~B\x83\xf4V\x07\xd3\xfb\xb… ┆ 512 ┆ 0.5 │ │ b"\xee$4@f\xc14\x07\x8e\x88\x1… ┆ 256 ┆ 0.25 │ - │ b"~\x17\x9c\xb1\xf4\xdb?\xe9\x… ┆ 2560 ┆ 2.5 │ │ b"\x80\xbd\xb9nEq;2\x99$\xf9\x… ┆ 1024 ┆ 1.0 │ └─────────────────────────────────┴─────────┴─────────────┘ """ diff --git a/py-polars/polars/expr/datetime.py b/py-polars/polars/expr/datetime.py index cdf6ccb6516f..9a03b46b12d3 100644 --- a/py-polars/polars/expr/datetime.py +++ b/py-polars/polars/expr/datetime.py @@ -284,10 +284,12 @@ def round(self, every: str | dt.timedelta | IntoExprColumn) -> Expr: This functionality is considered **unstable**. It may be changed at any point without it being considered a breaking change. - Each date/datetime in the first half of the interval - is mapped to the start of its bucket. - Each date/datetime in the second half of the interval - is mapped to the end of its bucket. + - Each date/datetime in the first half of the interval + is mapped to the start of its bucket. + - Each date/datetime in the second half of the interval + is mapped to the end of its bucket. + - Half-way points are mapped to the start of their bucket. + Ambiguous results are localised using the DST offset of the original timestamp - for example, rounding `'2022-11-06 01:20:00 CST'` by `'1h'` results in `'2022-11-06 01:00:00 CST'`, whereas rounding `'2022-11-06 01:20:00 CDT'` by diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 69e8a375a4dd..e24297aa93ee 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -681,9 +681,9 @@ def alias(self, name: str) -> Expr: See Also -------- - map - prefix - suffix + name.map + name.prefix + name.suffix Examples -------- @@ -2494,7 +2494,7 @@ def sort_by( ) def gather( - self, indices: int | Sequence[int] | Expr | Series | np.ndarray[Any, Any] + self, indices: int | Sequence[int] | IntoExpr | Series | np.ndarray[Any, Any] ) -> Expr: """ Take values by index. @@ -2541,8 +2541,10 @@ def gather( │ two ┆ [4, 99] │ └───────┴───────────┘ """ - if isinstance(indices, Sequence) or ( - _check_for_numpy(indices) and isinstance(indices, np.ndarray) + if ( + isinstance(indices, Sequence) + and not isinstance(indices, str) + or (_check_for_numpy(indices) and isinstance(indices, np.ndarray)) ): indices_lit = F.lit(pl.Series("", indices, dtype=Int64))._pyexpr else: @@ -4298,14 +4300,14 @@ def map_batches( Dtype of the output Series. If not set, the dtype will be inferred based on the first non-null value that is returned by the function. - is_elementwise - If set to true this can run in the streaming engine, but may yield - incorrect results in group-by. Ensure you know what you are doing! agg_list Aggregate the values of the expression into a list before applying the function. This parameter only works in a group-by context. The function will be invoked only once on a list of groups, rather than once per group. 
+ is_elementwise + If set to true this can run in the streaming engine, but may yield + incorrect results in group-by. Ensure you know what you are doing! returns_scalar If the function returns a scalar, by default it will be wrapped in a list in the output, since the assumption is that the function @@ -4743,7 +4745,7 @@ def flatten(self) -> Expr: """ Flatten a list or string column. - Alias for :func:`polars.expr.list.ExprListNameSpace.explode`. + Alias for :func:`Expr.list.explode`. Examples -------- @@ -4883,7 +4885,7 @@ def head(self, n: int | Expr = 10) -> Expr: Examples -------- >>> df = pl.DataFrame({"foo": [1, 2, 3, 4, 5, 6, 7]}) - >>> df.head(3) + >>> df.select(pl.col("foo").head(3)) shape: (3, 1) ┌─────┐ │ foo │ @@ -4909,7 +4911,7 @@ def tail(self, n: int | Expr = 10) -> Expr: Examples -------- >>> df = pl.DataFrame({"foo": [1, 2, 3, 4, 5, 6, 7]}) - >>> df.tail(3) + >>> df.select(pl.col("foo").tail(3)) shape: (3, 1) ┌─────┐ │ foo │ @@ -4940,7 +4942,7 @@ def limit(self, n: int | Expr = 10) -> Expr: Examples -------- >>> df = pl.DataFrame({"foo": [1, 2, 3, 4, 5, 6, 7]}) - >>> df.limit(3) + >>> df.select(pl.col("foo").limit(3)) shape: (3, 1) ┌─────┐ │ foo │ @@ -8744,30 +8746,31 @@ def upper_bound(self) -> Expr: def sign(self) -> Expr: """ - Compute the element-wise indication of the sign. + Compute the element-wise sign function on numeric types. - The returned values can be -1, 0, or 1: + The returned value is computed as follows: - * -1 if x < 0. - * 0 if x == 0. - * 1 if x > 0. + * -1 if x < 0. + * 1 if x > 0. + * x otherwise (typically 0, but could be NaN if the input is). - (null values are preserved as-is). + Null values are preserved as-is, and the dtype of the input is preserved. Examples -------- - >>> df = pl.DataFrame({"a": [-9.0, -0.0, 0.0, 4.0, None]}) - >>> df.select(pl.col("a").sign()) - shape: (5, 1) + >>> df = pl.DataFrame({"a": [-9.0, -0.0, 0.0, 4.0, float("nan"), None]}) + >>> df.select(pl.col.a.sign()) + shape: (6, 1) ┌──────┐ │ a │ │ --- │ - │ i64 │ + │ f64 │ ╞══════╡ - │ -1 │ - │ 0 │ - │ 0 │ - │ 1 │ + │ -1.0 │ + │ -0.0 │ + │ 0.0 │ + │ 1.0 │ + │ NaN │ │ null │ └──────┘ """ @@ -9211,6 +9214,9 @@ def shuffle(self, seed: int | None = None) -> Expr: """ Shuffle the contents of this expression. + Note that this is shuffled independently of any other column or expression. If you + want each row to stay the same, use `df.sample(shuffle=True)`. + Parameters ---------- seed @@ -9305,7 +9311,7 @@ def ewm_mean( ignore_nulls: bool = False, ) -> Expr: r""" - Exponentially-weighted moving average. + Compute exponentially-weighted moving average. Parameters ---------- @@ -9320,11 +9326,11 @@ def ewm_mean( .. math:: \alpha = \frac{2}{\theta + 1} \; \forall \; \theta \geq 1 half_life - Specify decay in terms of half-life, :math:`\lambda`, with + Specify decay in terms of half-life, :math:`\tau`, with .. math:: - \alpha = 1 - \exp \left\{ \frac{ -\ln(2) }{ \lambda } \right\} \; - \forall \; \lambda > 0 + \alpha = 1 - \exp \left\{ \frac{ -\ln(2) }{ \tau } \right\} \; + \forall \; \tau > 0 alpha Specify smoothing factor alpha directly, :math:`0 < \alpha \leq 1`. adjust @@ -9387,20 +9393,21 @@ def ewm_mean_by( half_life: str | timedelta, ) -> Expr: r""" - Calculate time-based exponentially weighted moving average. + Compute time-based exponentially weighted moving average.
- Given observations :math:`x_1, x_2, \ldots, x_n` at times - :math:`t_1, t_2, \ldots, t_n`, the EWMA is calculated as + Given observations :math:`x_0, x_1, \ldots, x_{n-1}` at times + :math:`t_0, t_1, \ldots, t_{n-1}`, the EWMA is calculated as .. math:: y_0 &= x_0 - \alpha_i &= \exp(-\lambda(t_i - t_{i-1})) + \alpha_i &= 1 - \exp \left\{ \frac{ -\ln(2)(t_i-t_{i-1}) } + { \tau } \right\} y_i &= \alpha_i x_i + (1 - \alpha_i) y_{i-1}; \quad i > 0 - where :math:`\lambda` equals :math:`\ln(2) / \text{half_life}`. + where :math:`\tau` is the `half_life`. Parameters ---------- @@ -9484,7 +9491,7 @@ def ewm_std( ignore_nulls: bool = False, ) -> Expr: r""" - Exponentially-weighted moving standard deviation. + Compute exponentially-weighted moving standard deviation. Parameters ---------- @@ -9575,7 +9582,7 @@ def ewm_var( ignore_nulls: bool = False, ) -> Expr: r""" - Exponentially-weighted moving variance. + Compute exponentially-weighted moving variance. Parameters ---------- @@ -10240,7 +10247,12 @@ def replace( old, new, default=default, return_dtype=return_dtype ) - if new is no_default and isinstance(old, Mapping): + if new is no_default: + if not isinstance(old, Mapping): + msg = ( + "`new` argument is required if `old` argument is not a Mapping type" + ) + raise TypeError(msg) new = pl.Series(old.values()) old = pl.Series(old.keys()) else: @@ -10250,7 +10262,7 @@ def replace( new = pl.Series(new) old = parse_into_expression(old, str_as_lit=True) # type: ignore[arg-type] - new = parse_into_expression(new, str_as_lit=True) # type: ignore[arg-type] + new = parse_into_expression(new, str_as_lit=True) result = self._from_pyexpr(self._pyexpr.replace(old, new)) @@ -10431,7 +10443,12 @@ def replace_strict( │ 3 ┆ 1.0 ┆ 10.0 │ └─────┴─────┴──────────┘ """ # noqa: W505 - if new is no_default and isinstance(old, Mapping): + if new is no_default: + if not isinstance(old, Mapping): + msg = ( + "`new` argument is required if `old` argument is not a Mapping type" + ) + raise TypeError(msg) new = pl.Series(old.values()) old = pl.Series(old.keys()) diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 390904997697..5655b58c86cb 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -1005,7 +1005,7 @@ def explode(self) -> Expr: See Also -------- - ExprNameSpace.reshape: Reshape this Expr to a flat Series or a Series of Lists. + Expr.reshape: Reshape this Expr to a flat Series or a Series of Lists. Examples -------- diff --git a/py-polars/polars/expr/name.py b/py-polars/polars/expr/name.py index 9c730d2d3206..8b6fe24d8dea 100644 --- a/py-polars/polars/expr/name.py +++ b/py-polars/polars/expr/name.py @@ -286,17 +286,22 @@ def to_uppercase(self) -> Expr: def map_fields(self, function: Callable[[str], str]) -> Expr: """ - Rename fields of a struct by mapping a function over the field name. + Rename fields of a struct by mapping a function over the field name(s). Notes ----- - This only take effects for struct. + This only takes effect for struct columns. Parameters ---------- function Function that maps a field name to a new name. + See Also + -------- + prefix_fields + suffix_fields + Examples -------- >>> df = pl.DataFrame({"x": {"a": 1, "b": 2}}) @@ -307,16 +312,21 @@ def map_fields(self, function: Callable[[str], str]) -> Expr: def prefix_fields(self, prefix: str) -> Expr: """ - Add a prefix to all fields name of a struct. + Add a prefix to all field names of a struct. Notes ----- - This only take effects for struct. 
+ This only takes effect for struct columns. Parameters ---------- prefix - Prefix to add to the filed name + Prefix to add to the field name. + + See Also + -------- + map_fields + suffix_fields Examples -------- @@ -328,16 +338,21 @@ def prefix_fields(self, prefix: str) -> Expr: def suffix_fields(self, suffix: str) -> Expr: """ - Add a suffix to all fields name of a struct. + Add a suffix to all field names of a struct. Notes ----- - This only take effects for struct. + This only takes effect for struct columns. Parameters ---------- suffix - Suffix to add to the filed name + Suffix to add to the field name. + + See Also + -------- + map_fields + prefix_fields Examples -------- diff --git a/py-polars/polars/expr/string.py b/py-polars/polars/expr/string.py index c64662458d53..351c1cdd8565 100644 --- a/py-polars/polars/expr/string.py +++ b/py-polars/polars/expr/string.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Mapping import polars._reexport as pl from polars import functions as F @@ -11,7 +11,7 @@ ) from polars._utils.parse import parse_into_expression from polars._utils.unstable import unstable -from polars._utils.various import find_stacklevel +from polars._utils.various import find_stacklevel, no_default from polars._utils.wrap import wrap_expr from polars.datatypes import Date, Datetime, Time, parse_into_dtype from polars.datatypes.constants import N_INFER_DEFAULT @@ -28,6 +28,7 @@ TimeUnit, TransferEncoding, ) + from polars._utils.various import NoDefault class ExprStringNameSpace: @@ -447,7 +448,7 @@ def len_chars(self) -> Expr: def to_uppercase(self) -> Expr: """ - Transform to uppercase variant. + Modify strings to their uppercase equivalent. Examples -------- @@ -467,7 +468,7 @@ def to_uppercase(self) -> Expr: def to_lowercase(self) -> Expr: """ - Transform to lowercase variant. + Modify strings to their lowercase equivalent. Examples -------- @@ -487,22 +488,37 @@ def to_lowercase(self) -> Expr: def to_titlecase(self) -> Expr: """ - Transform to titlecase variant. + Modify strings to their titlecase equivalent. + + Notes + ----- + This is a form of case transform where the first letter of each word is + capitalized, with the rest of the word in lowercase. Non-alphanumeric + characters define the word boundaries. Examples -------- >>> df = pl.DataFrame( - ... {"sing": ["welcome to my world", "THERE'S NO TURNING BACK"]} + ... { + ... "quotes": [ + ... "'e.t. phone home'", + ... "you talkin' to me?", + ... "to infinity,and BEYOND!", + ... ] + ... } ... ) - >>> df.with_columns(foo_title=pl.col("sing").str.to_titlecase()) - shape: (2, 2) + >>> df.with_columns( + ... quotes_title=pl.col("quotes").str.to_titlecase(), + ... ) + shape: (3, 2) ┌─────────────────────────┬─────────────────────────┐ - │ sing ┆ foo_title │ + │ quotes ┆ quotes_title │ │ --- ┆ --- │ │ str ┆ str │ ╞═════════════════════════╪═════════════════════════╡ - │ welcome to my world ┆ Welcome To My World │ - │ THERE'S NO TURNING BACK ┆ There's No Turning Back │ + │ 'e.t. phone home' ┆ 'E.T. Phone Home' │ + │ you talkin' to me? ┆ You Talkin' To Me? │ + │ to infinity,and BEYOND! ┆ To Infinity,And Beyond! │ └─────────────────────────┴─────────────────────────┘ """ return wrap_expr(self._pyexpr.str_to_titlecase()) @@ -908,7 +924,7 @@ def contains( self, pattern: str | Expr, *, literal: bool = False, strict: bool = True ) -> Expr: """ - Check if string contains a substring that matches a pattern. 
+ Check if the string contains a substring that matches a pattern. Parameters ---------- @@ -1019,7 +1035,7 @@ def find( See Also -------- - contains : Check if string contains a substring that matches a regex. + contains : Check if the string contains a substring that matches a pattern. Examples -------- @@ -1078,7 +1094,7 @@ def ends_with(self, suffix: str | Expr) -> Expr: See Also -------- - contains : Check if string contains a substring that matches a regex. + contains : Check if the string contains a substring that matches a pattern. starts_with : Check if string values start with a substring. Examples @@ -1141,7 +1157,7 @@ def starts_with(self, prefix: str | Expr) -> Expr: See Also -------- - contains : Check if string contains a substring that matches a regex. + contains : Check if the string contains a substring that matches a pattern. ends_with : Check if string values end with a substring. Examples @@ -2385,9 +2401,9 @@ def contains_any( self, patterns: IntoExpr, *, ascii_case_insensitive: bool = False ) -> Expr: """ - Use the aho-corasick algorithm to find matches. + Use the Aho-Corasick algorithm to find matches. - This version determines if any of the patterns find a match. + Determines if any of the patterns are contained in the string. Parameters ---------- @@ -2398,6 +2414,11 @@ def contains_any( When this option is enabled, searching will be performed without respect to case for ASCII letters (a-z and A-Z) only. + Notes + ----- + This method supports matching on string literals only, and does not support + regular expression matching. + Examples -------- >>> _ = pl.Config.set_fmt_str_lengths(100) @@ -2433,29 +2454,75 @@ def contains_any( def replace_many( self, - patterns: IntoExpr, - replace_with: IntoExpr, + patterns: IntoExpr | Mapping[str, str], + replace_with: IntoExpr | NoDefault = no_default, *, ascii_case_insensitive: bool = False, ) -> Expr: """ - - Use the aho-corasick algorithm to replace many matches. + Use the Aho-Corasick algorithm to replace many matches. Parameters ---------- patterns String patterns to search and replace. + Accepts expression input. Strings are parsed as column names, and other + non-expression inputs are parsed as literals. Also accepts a mapping of + patterns to their replacement as syntactic sugar for + `replace_many(pl.Series(mapping.keys()), pl.Series(mapping.values()))`. replace_with Strings to replace where a pattern was a match. - This can be broadcast, so it supports many:one and many:many. + Accepts expression input. Non-expression inputs are parsed as literals. + Length must match the length of `patterns` or have length 1. This can be + broadcasted, so it supports many:one and many:many. ascii_case_insensitive Enable ASCII-aware case-insensitive matching. When this option is enabled, searching will be performed without respect to case for ASCII letters (a-z and A-Z) only. + Notes + ----- + This method supports matching on string literals only, and does not support + regular expression matching. + Examples -------- + Replace many patterns by passing sequences of equal length to the `patterns` and + `replace_with` parameters. + + >>> _ = pl.Config.set_fmt_str_lengths(100) + >>> _ = pl.Config.set_tbl_width_chars(110) + >>> df = pl.DataFrame( + ... { + ... "lyrics": [ + ... "Everybody wants to rule the world", + ... "Tell me what you want, what you really really want", + ... "Can you feel the love tonight", + ... ] + ... } + ... ) + >>> df.with_columns( + ... pl.col("lyrics") + ... .str.replace_many( + ... ["me", "you"], + ... 
["you", "me"], + ... ) + ... .alias("confusing") + ... ) + shape: (3, 2) + ┌────────────────────────────────────────────────────┬───────────────────────────────────────────────────┐ + │ lyrics ┆ confusing │ + │ --- ┆ --- │ + │ str ┆ str │ + ╞════════════════════════════════════════════════════╪═══════════════════════════════════════════════════╡ + │ Everybody wants to rule the world ┆ Everybody wants to rule the world │ + │ Tell me what you want, what you really really want ┆ Tell you what me want, what me really really want │ + │ Can you feel the love tonight ┆ Can me feel the love tonight │ + └────────────────────────────────────────────────────┴───────────────────────────────────────────────────┘ + + Broadcast a replacement for many patterns by passing a string or a sequence of + length 1 to the `replace_with` parameter. + >>> _ = pl.Config.set_fmt_str_lengths(100) >>> df = pl.DataFrame( ... { @@ -2484,27 +2551,50 @@ def replace_many( │ Tell me what you want, what you really really want ┆ Tell what want, what really really want │ │ Can you feel the love tonight ┆ Can feel the love tonight │ └────────────────────────────────────────────────────┴────────────────────────────────────────────┘ + + Passing a mapping with patterns and replacements is also supported as syntactic + sugar. + + >>> _ = pl.Config.set_fmt_str_lengths(100) + >>> _ = pl.Config.set_tbl_width_chars(110) + >>> df = pl.DataFrame( + ... { + ... "lyrics": [ + ... "Everybody wants to rule the world", + ... "Tell me what you want, what you really really want", + ... "Can you feel the love tonight", + ... ] + ... } + ... ) + >>> mapping = {"me": "you", "you": "me", "want": "need"} >>> df.with_columns( - ... pl.col("lyrics") - ... .str.replace_many( - ... ["me", "you"], - ... ["you", "me"], - ... ) - ... .alias("confusing") - ... ) # doctest: +IGNORE_RESULT + ... pl.col("lyrics").str.replace_many(mapping).alias("confusing") + ... ) shape: (3, 2) ┌────────────────────────────────────────────────────┬───────────────────────────────────────────────────┐ │ lyrics ┆ confusing │ │ --- ┆ --- │ │ str ┆ str │ ╞════════════════════════════════════════════════════╪═══════════════════════════════════════════════════╡ - │ Everybody wants to rule the world ┆ Everybody wants to rule the world │ - │ Tell me what you want, what you really really want ┆ Tell you what me want, what me really really want │ + │ Everybody wants to rule the world ┆ Everybody needs to rule the world │ + │ Tell me what you want, what you really really want ┆ Tell you what me need, what me really really need │ │ Can you feel the love tonight ┆ Can me feel the love tonight │ └────────────────────────────────────────────────────┴───────────────────────────────────────────────────┘ """ # noqa: W505 + if replace_with is no_default: + if not isinstance(patterns, Mapping): + msg = "`replace_with` argument is required if `patterns` argument is not a Mapping type" + raise TypeError(msg) + # Early return in case of an empty mapping. + if not patterns: + return wrap_expr(self._pyexpr) + replace_with = pl.Series(patterns.values()) + patterns = pl.Series(patterns.keys()) + patterns = parse_into_expression( - patterns, str_as_lit=False, list_as_series=True + patterns, # type: ignore[arg-type] + str_as_lit=False, + list_as_series=True, ) replace_with = parse_into_expression( replace_with, str_as_lit=True, list_as_series=True @@ -2524,8 +2614,7 @@ def extract_many( overlapping: bool = False, ) -> Expr: """ - - Use the aho-corasick algorithm to extract many matches. 
+ Use the Aho-Corasick algorithm to extract many matches. Parameters ---------- @@ -2538,6 +2627,11 @@ def extract_many( overlapping Whether matches may overlap. + Notes + ----- + This method supports matching on string literals only, and does not support + regular expression matching. + Examples -------- >>> _ = pl.Config.set_fmt_str_lengths(100) diff --git a/py-polars/polars/functions/eager.py b/py-polars/polars/functions/eager.py index a841520be8c5..e8cbb00e3dca 100644 --- a/py-polars/polars/functions/eager.py +++ b/py-polars/polars/functions/eager.py @@ -273,7 +273,14 @@ def join_func( idx_y: tuple[int, LazyFrame], ) -> tuple[int, LazyFrame]: (_, x), (y_idx, y) = idx_x, idx_y - return y_idx, x.join(y, how=how, on=align_on, suffix=f":{y_idx}", coalesce=True) + return y_idx, x.join( + y, + how=how, + on=align_on, + suffix=f":{y_idx}", + join_nulls=True, + coalesce=True, + ) joined = reduce(join_func, idx_frames)[1].sort(by=align_on, descending=descending) if post_align_collect: diff --git a/py-polars/polars/functions/lazy.py b/py-polars/polars/functions/lazy.py index a7138f19bfd5..8ba891d70a59 100644 --- a/py-polars/polars/functions/lazy.py +++ b/py-polars/polars/functions/lazy.py @@ -180,9 +180,6 @@ def cum_count(*columns: str, reverse: bool = False) -> Expr: This function is syntactic sugar for `col(columns).cum_count()`. - If no arguments are passed, returns the cumulative count of a context. - Rows containing null values count towards the result. - Parameters ---------- *columns diff --git a/py-polars/polars/functions/lit.py b/py-polars/polars/functions/lit.py index 92d43f88b726..8853963cbeed 100644 --- a/py-polars/polars/functions/lit.py +++ b/py-polars/polars/functions/lit.py @@ -6,14 +6,8 @@ from typing import TYPE_CHECKING, Any import polars._reexport as pl -from polars._utils.convert import ( - date_to_int, - datetime_to_int, - time_to_int, - timedelta_to_int, -) from polars._utils.wrap import wrap_expr -from polars.datatypes import Date, Datetime, Duration, Enum, Time +from polars.datatypes import Date, Datetime, Duration, Enum from polars.dependencies import _check_for_numpy from polars.dependencies import numpy as np @@ -78,44 +72,64 @@ def lit( time_unit: TimeUnit if isinstance(value, datetime): + if dtype == Date: + return wrap_expr(plr.lit(value.date(), allow_object=False)) + + # parse time unit if dtype is not None and (tu := getattr(dtype, "time_unit", "us")) is not None: time_unit = tu # type: ignore[assignment] else: time_unit = "us" - time_zone: str | None = getattr(dtype, "time_zone", None) - if (tzinfo := value.tzinfo) is not None: - tzinfo_str = str(tzinfo) - if time_zone is not None and time_zone != tzinfo_str: - msg = f"time zone of dtype ({time_zone!r}) differs from time zone of value ({tzinfo!r})" + # parse time zone + dtype_tz = getattr(dtype, "time_zone", None) + value_tz = value.tzinfo + if value_tz is None: + tz = dtype_tz + else: + if dtype_tz is None: + # value has time zone, but dtype does not: keep value time zone + tz = str(value_tz) + elif str(value_tz) == dtype_tz: + # dtype and value both have same time zone + tz = str(value_tz) + else: + # value has time zone that differs from dtype time zone + msg = ( + f"time zone of dtype ({dtype_tz!r}) differs from time zone of " + f"value ({value_tz!r})" + ) raise TypeError(msg) - time_zone = tzinfo_str dt_utc = value.replace(tzinfo=timezone.utc) - dt_int = datetime_to_int(dt_utc, time_unit) - expr = lit(dt_int).cast(Datetime(time_unit)) - if time_zone is not None: + expr = wrap_expr(plr.lit(dt_utc, 
allow_object=False)).cast(Datetime(time_unit)) + if tz is not None: expr = expr.dt.replace_time_zone( - time_zone, ambiguous="earliest" if value.fold == 0 else "latest" + tz, ambiguous="earliest" if value.fold == 0 else "latest" ) return expr elif isinstance(value, timedelta): - if dtype is not None and (tu := getattr(dtype, "time_unit", "us")) is not None: - time_unit = tu # type: ignore[assignment] - else: - time_unit = "us" - - td_int = timedelta_to_int(value, time_unit) - return lit(td_int).cast(Duration(time_unit)) + expr = wrap_expr(plr.lit(value, allow_object=False)) + if dtype is not None and (tu := getattr(dtype, "time_unit", None)) is not None: + expr = expr.cast(Duration(tu)) + return expr elif isinstance(value, time): - time_int = time_to_int(value) - return lit(time_int).cast(Time) + return wrap_expr(plr.lit(value, allow_object=False)) elif isinstance(value, date): - date_int = date_to_int(value) - return lit(date_int).cast(Date) + if dtype == Datetime: + time_unit = getattr(dtype, "time_unit", "us") or "us" + dt_utc = datetime(value.year, value.month, value.day) + expr = wrap_expr(plr.lit(dt_utc, allow_object=False)).cast( + Datetime(time_unit) + ) + if (time_zone := getattr(dtype, "time_zone", None)) is not None: + expr = expr.dt.replace_time_zone(str(time_zone)) + return expr + else: + return wrap_expr(plr.lit(value, allow_object=False)) elif isinstance(value, pl.Series): value = value._s diff --git a/py-polars/polars/functions/range/date_range.py b/py-polars/polars/functions/range/date_range.py index 1ed407194a3d..d0fac37a09b1 100644 --- a/py-polars/polars/functions/range/date_range.py +++ b/py-polars/polars/functions/range/date_range.py @@ -140,6 +140,36 @@ def date_range( 1985-01-07 1985-01-09 ] + + Omit `eager=True` if you want to use `date_range` as an expression: + + >>> df = pl.DataFrame( + ... { + ... "date": [ + ... date(2024, 1, 1), + ... date(2024, 1, 2), + ... date(2024, 1, 1), + ... date(2024, 1, 3), + ... ], + ... "key": ["one", "one", "two", "two"], + ... } + ... ) + >>> result = ( + ... df.group_by("key") + ... .agg(pl.date_range(pl.col("date").min(), pl.col("date").max())) + ... .sort("key") + ... ) + >>> with pl.Config(fmt_str_lengths=50): + ... print(result) + shape: (2, 2) + ┌─────┬──────────────────────────────────────┐ + │ key ┆ date │ + │ --- ┆ --- │ + │ str ┆ list[date] │ + ╞═════╪══════════════════════════════════════╡ + │ one ┆ [2024-01-01, 2024-01-02] │ + │ two ┆ [2024-01-01, 2024-01-02, 2024-01-03] │ + └─────┴──────────────────────────────────────┘ """ interval = parse_interval_argument(interval) diff --git a/py-polars/polars/functions/range/datetime_range.py b/py-polars/polars/functions/range/datetime_range.py index 3fe2c2767481..9e994917ebbc 100644 --- a/py-polars/polars/functions/range/datetime_range.py +++ b/py-polars/polars/functions/range/datetime_range.py @@ -177,6 +177,36 @@ def datetime_range( 2022-02-01 00:00:00 EST 2022-03-01 00:00:00 EST ] + + Omit `eager=True` if you want to use `datetime_range` as an expression: + + >>> df = pl.DataFrame( + ... { + ... "date": [ + ... date(2024, 1, 1), + ... date(2024, 1, 2), + ... date(2024, 1, 1), + ... date(2024, 1, 3), + ... ], + ... "key": ["one", "one", "two", "two"], + ... } + ... ) + >>> result = ( + ... df.group_by("key") + ... .agg(pl.datetime_range(pl.col("date").min(), pl.col("date").max())) + ... .sort("key") + ... ) + >>> with pl.Config(fmt_str_lengths=70): + ... 
print(result) + shape: (2, 2) + ┌─────┬─────────────────────────────────────────────────────────────────┐ + │ key ┆ date │ + │ --- ┆ --- │ + │ str ┆ list[datetime[μs]] │ + ╞═════╪═════════════════════════════════════════════════════════════════╡ + │ one ┆ [2024-01-01 00:00:00, 2024-01-02 00:00:00] │ + │ two ┆ [2024-01-01 00:00:00, 2024-01-02 00:00:00, 2024-01-03 00:00:00] │ + └─────┴─────────────────────────────────────────────────────────────────┘ """ interval = parse_interval_argument(interval) if time_unit is None and "ns" in interval: diff --git a/py-polars/polars/io/csv/batched_reader.py b/py-polars/polars/io/csv/batched_reader.py index 57cca5d366d3..f13efb7aa3b3 100644 --- a/py-polars/polars/io/csv/batched_reader.py +++ b/py-polars/polars/io/csv/batched_reader.py @@ -121,7 +121,7 @@ def next_batches(self, n: int) -> list[DataFrame] | None: Examples -------- >>> reader = pl.read_csv_batched( - ... "./tpch/tables_scale_100/lineitem.tbl", + ... "./pdsh/tables_scale_100/lineitem.tbl", ... separator="|", ... try_parse_dates=True, ... ) # doctest: +SKIP diff --git a/py-polars/polars/io/csv/functions.py b/py-polars/polars/io/csv/functions.py index 7b3d3a91dbf3..ceba49391560 100644 --- a/py-polars/polars/io/csv/functions.py +++ b/py-polars/polars/io/csv/functions.py @@ -11,6 +11,7 @@ from polars._utils.deprecation import deprecate_renamed_parameter from polars._utils.various import ( _process_null_values, + is_path_or_str_sequence, is_str_sequence, normalize_filepath, ) @@ -443,6 +444,8 @@ def read_csv( # * The `storage_options` configuration keys are different between # fsspec and object_store (would require a breaking change) ): + source = normalize_filepath(v, check_not_directory=False) + if schema_overrides_is_list: msg = "passing a list to `schema_overrides` is unsupported for hf:// paths" raise ValueError(msg) @@ -451,7 +454,7 @@ def read_csv( raise ValueError(msg) lf = _scan_csv_impl( - source, # type: ignore[arg-type] + source, has_header=has_header, separator=separator, comment_prefix=comment_prefix, @@ -827,7 +830,7 @@ def read_csv_batched( Examples -------- >>> reader = pl.read_csv_batched( - ... "./tpch/tables_scale_100/lineitem.tbl", + ... "./pdsh/tables_scale_100/lineitem.tbl", ... separator="|", ... try_parse_dates=True, ... 
) # doctest: +SKIP @@ -984,7 +987,16 @@ def read_csv_batched( @deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4") @deprecate_renamed_parameter("row_count_offset", "row_index_offset", version="0.20.4") def scan_csv( - source: str | Path | list[str] | list[Path], + source: str + | Path + | IO[str] + | IO[bytes] + | bytes + | list[str] + | list[Path] + | list[IO[str]] + | list[IO[bytes]] + | list[bytes], *, has_header: bool = True, separator: str = ",", @@ -1232,7 +1244,7 @@ def with_column_names(cols: list[str]) -> list[str]: if isinstance(source, (str, Path)): source = normalize_filepath(source, check_not_directory=False) - else: + elif is_path_or_str_sequence(source, allow_str=False): source = [ normalize_filepath(source, check_not_directory=False) for source in source ] @@ -1276,7 +1288,15 @@ def with_column_names(cols: list[str]) -> list[str]: def _scan_csv_impl( - source: str | list[str] | list[Path], + source: str + | IO[str] + | IO[bytes] + | bytes + | list[str] + | list[Path] + | list[IO[str]] + | list[IO[bytes]] + | list[bytes], *, has_header: bool = True, separator: str = ",", @@ -1329,8 +1349,8 @@ def _scan_csv_impl( storage_options = None pylf = PyLazyFrame.new_from_csv( - path=source, - paths=sources, + source, + sources, separator=separator, has_header=has_header, ignore_errors=ignore_errors, diff --git a/py-polars/polars/io/database/_executor.py b/py-polars/polars/io/database/_executor.py index ef044d70d139..bb56d48c6b2f 100644 --- a/py-polars/polars/io/database/_executor.py +++ b/py-polars/polars/io/database/_executor.py @@ -9,10 +9,12 @@ from polars import functions as F from polars._utils.various import parse_version from polars.convert import from_arrow -from polars.datatypes import ( - N_INFER_DEFAULT, +from polars.datatypes import N_INFER_DEFAULT +from polars.exceptions import ( + DuplicateError, + ModuleUpgradeRequiredError, + UnsuitableSQLError, ) -from polars.exceptions import ModuleUpgradeRequiredError, UnsuitableSQLError from polars.io.database._arrow_registry import ARROW_DRIVER_REGISTRY from polars.io.database._cursor_proxies import ODBCCursorProxy, SurrealDBCursorProxy from polars.io.database._inference import _infer_dtype_from_cursor_description @@ -266,25 +268,25 @@ def _from_rows( if hasattr(self.result, "fetchall"): if self.driver_name == "sqlalchemy": if hasattr(self.result, "cursor"): - cursor_desc = { - d[0]: d[1:] for d in self.result.cursor.description - } + cursor_desc = [ + (d[0], d[1:]) for d in self.result.cursor.description + ] elif hasattr(self.result, "_metadata"): - cursor_desc = {k: None for k in self.result._metadata.keys} + cursor_desc = [(k, None) for k in self.result._metadata.keys] else: msg = f"Unable to determine metadata from query result; {self.result!r}" raise ValueError(msg) elif hasattr(self.result, "description"): - cursor_desc = {d[0]: d[1:] for d in self.result.description} + cursor_desc = [(d[0], d[1:]) for d in self.result.description] else: - cursor_desc = {} + cursor_desc = [] schema_overrides = self._inject_type_overrides( description=cursor_desc, schema_overrides=(schema_overrides or {}), ) - result_columns = list(cursor_desc) + result_columns = [nm for nm, _ in cursor_desc] frames = ( DataFrame( data=rows, @@ -307,7 +309,7 @@ def _from_rows( def _inject_type_overrides( self, - description: dict[str, Any], + description: list[tuple[str, Any]], schema_overrides: SchemaDict, ) -> SchemaDict: """ @@ -320,11 +322,16 @@ def _inject_type_overrides( We currently only do the additional inference from 
string/python type values. (Further refinement will require per-driver module knowledge and lookups). """ - for nm, desc in description.items(): - if desc is not None and nm not in schema_overrides: + dupe_check = set() + for nm, desc in description: + if nm in dupe_check: + msg = f"column {nm!r} appears more than once in the query/result cursor" + raise DuplicateError(msg) + elif desc is not None and nm not in schema_overrides: dtype = _infer_dtype_from_cursor_description(self.cursor, desc) if dtype is not None: schema_overrides[nm] = dtype # type: ignore[index] + dupe_check.add(nm) return schema_overrides @@ -384,7 +391,7 @@ def _normalise_cursor(self, conn: Any) -> Cursor: return conn.engine.raw_connection().cursor() elif conn.engine.driver == "duckdb_engine": self.driver_name = "duckdb" - return conn.engine.raw_connection().driver_connection.c + return conn.engine.raw_connection().driver_connection elif self._is_alchemy_engine(conn): # note: if we create it, we can close it self.can_close_cursor = True diff --git a/py-polars/polars/io/ipc/functions.py b/py-polars/polars/io/ipc/functions.py index 4443c31d513f..43fbc8136de2 100644 --- a/py-polars/polars/io/ipc/functions.py +++ b/py-polars/polars/io/ipc/functions.py @@ -9,6 +9,7 @@ import polars.functions as F from polars._utils.deprecation import deprecate_renamed_parameter from polars._utils.various import ( + is_path_or_str_sequence, is_str_sequence, normalize_filepath, ) @@ -111,9 +112,8 @@ def read_ipc( raise ValueError(msg) lf = scan_ipc( - source, # type: ignore[arg-type] + source, n_rows=n_rows, - memory_map=memory_map, storage_options=storage_options, row_index_name=row_index_name, row_index_offset=row_index_offset, @@ -188,7 +188,6 @@ def _read_ipc_impl( rechunk=rechunk, row_index_name=row_index_name, row_index_offset=row_index_offset, - memory_map=memory_map, ) if columns is None: df = scan.collect() @@ -346,7 +345,14 @@ def read_ipc_schema(source: str | Path | IO[bytes] | bytes) -> dict[str, DataTyp @deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4") @deprecate_renamed_parameter("row_count_offset", "row_index_offset", version="0.20.4") def scan_ipc( - source: str | Path | list[str] | list[Path], + source: str + | Path + | IO[bytes] + | bytes + | list[str] + | list[Path] + | list[IO[bytes]] + | list[bytes], *, n_rows: int | None = None, cache: bool = True, @@ -426,15 +432,23 @@ def scan_ipc( include_file_paths Include the path of the source file(s) as a column with this name. 
""" + sources: list[str] | list[Path] | list[IO[bytes]] | list[bytes] = [] if isinstance(source, (str, Path)): source = normalize_filepath(source, check_not_directory=False) - sources = [] - else: - sources = [ - normalize_filepath(source, check_not_directory=False) for source in source - ] + elif isinstance(source, list): + if is_path_or_str_sequence(source): + sources = [ + normalize_filepath(source, check_not_directory=False) + for source in source + ] + else: + sources = source + source = None # type: ignore[assignment] + # Memory Mapping is now a no-op + _ = memory_map + pylf = PyLazyFrame.new_from_ipc( source, sources, @@ -442,7 +456,6 @@ def scan_ipc( cache, rechunk, parse_row_index_args(row_index_name, row_index_offset), - memory_map=memory_map, cloud_options=storage_options, retries=retries, file_cache_ttl=file_cache_ttl, diff --git a/py-polars/polars/io/ndjson.py b/py-polars/polars/io/ndjson.py index e8eccca53ccd..cd9ea92bf3c0 100644 --- a/py-polars/polars/io/ndjson.py +++ b/py-polars/polars/io/ndjson.py @@ -3,10 +3,10 @@ import contextlib from io import BytesIO, StringIO from pathlib import Path -from typing import TYPE_CHECKING, Any, Sequence +from typing import IO, TYPE_CHECKING, Any, Sequence from polars._utils.deprecation import deprecate_renamed_parameter -from polars._utils.various import normalize_filepath +from polars._utils.various import is_path_or_str_sequence, normalize_filepath from polars._utils.wrap import wrap_df, wrap_ldf from polars.datatypes import N_INFER_DEFAULT from polars.io._utils import parse_row_index_args @@ -145,7 +145,7 @@ def read_ndjson( return df return scan_ndjson( - source, # type: ignore[arg-type] + source, schema=schema, schema_overrides=schema_overrides, infer_schema_length=infer_schema_length, @@ -166,7 +166,16 @@ def read_ndjson( @deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4") @deprecate_renamed_parameter("row_count_offset", "row_index_offset", version="0.20.4") def scan_ndjson( - source: str | Path | list[str] | list[Path], + source: str + | Path + | IO[str] + | IO[bytes] + | bytes + | list[str] + | list[Path] + | list[IO[str]] + | list[IO[bytes]] + | bytes, *, schema: SchemaDefinition | None = None, schema_overrides: SchemaDefinition | None = None, @@ -247,14 +256,20 @@ def scan_ndjson( include_file_paths Include the path of the source file(s) as a column with this name. 
""" + sources: list[str] | list[Path] | list[IO[str]] | list[IO[bytes]] = [] if isinstance(source, (str, Path)): source = normalize_filepath(source, check_not_directory=False) - sources = [] - else: - sources = [ - normalize_filepath(source, check_not_directory=False) for source in source - ] + elif isinstance(source, list): + if is_path_or_str_sequence(source): + sources = [ + normalize_filepath(source, check_not_directory=False) + for source in source + ] + else: + sources = source + source = None # type: ignore[assignment] + if infer_schema_length == 0: msg = "'infer_schema_length' should be positive" raise ValueError(msg) @@ -266,8 +281,8 @@ def scan_ndjson( storage_options = None pylf = PyLazyFrame.new_from_ndjson( - path=source, - paths=sources, + source, + sources, infer_schema_length=infer_schema_length, schema=schema, schema_overrides=schema_overrides, diff --git a/py-polars/polars/io/parquet/functions.py b/py-polars/polars/io/parquet/functions.py index 90b6137c4924..bc434b05cc2d 100644 --- a/py-polars/polars/io/parquet/functions.py +++ b/py-polars/polars/io/parquet/functions.py @@ -3,26 +3,27 @@ import contextlib import io from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, Sequence +from typing import IO, TYPE_CHECKING, Any import polars.functions as F +from polars import concat as plconcat from polars._utils.deprecation import deprecate_renamed_parameter from polars._utils.unstable import issue_unstable_warning from polars._utils.various import ( is_int_sequence, + is_path_or_str_sequence, normalize_filepath, ) -from polars._utils.wrap import wrap_df, wrap_ldf +from polars._utils.wrap import wrap_ldf from polars.convert import from_arrow from polars.dependencies import import_optional from polars.io._utils import ( - parse_columns_arg, parse_row_index_args, prepare_file_arg, ) with contextlib.suppress(ImportError): - from polars.polars import PyDataFrame, PyLazyFrame + from polars.polars import PyLazyFrame from polars.polars import read_parquet_schema as _read_parquet_schema if TYPE_CHECKING: @@ -33,7 +34,14 @@ @deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4") @deprecate_renamed_parameter("row_count_offset", "row_index_offset", version="0.20.4") def read_parquet( - source: str | Path | list[str] | list[Path] | IO[bytes] | bytes, + source: str + | Path + | IO[bytes] + | bytes + | list[str] + | list[Path] + | list[IO[bytes]] + | list[bytes], *, columns: list[int] | list[str] | None = None, n_rows: int | None = None, @@ -166,18 +174,11 @@ def read_parquet( ) # Read file and bytes inputs using `read_parquet` - elif isinstance(source, (io.IOBase, bytes)): - return _read_parquet_binary( - source, - columns=columns, - n_rows=n_rows, - parallel=parallel, - row_index_name=row_index_name, - row_index_offset=row_index_offset, - low_memory=low_memory, - use_statistics=use_statistics, - rechunk=rechunk, - ) + if isinstance(source, bytes): + source = io.BytesIO(source) + elif isinstance(source, list) and len(source) > 0 and isinstance(source[0], bytes): + assert all(isinstance(s, bytes) for s in source) + source = [io.BytesIO(s) for s in source] # type: ignore[arg-type, assignment] # For other inputs, defer to `scan_parquet` lf = scan_parquet( @@ -209,7 +210,14 @@ def read_parquet( def _read_parquet_with_pyarrow( - source: str | Path | list[str] | list[Path] | IO[bytes] | bytes, + source: str + | Path + | IO[bytes] + | bytes + | list[str] + | list[Path] + | list[IO[bytes]] + | list[bytes], *, columns: list[int] | list[str] | None = None, 
storage_options: dict[str, Any] | None = None, @@ -224,48 +232,35 @@ def _read_parquet_with_pyarrow( ) pyarrow_options = pyarrow_options or {} - with prepare_file_arg( - source, # type: ignore[arg-type] - use_pyarrow=True, - storage_options=storage_options, - ) as source_prep: - pa_table = pyarrow_parquet.read_table( - source_prep, - memory_map=memory_map, - columns=columns, - **pyarrow_options, - ) - return from_arrow(pa_table, rechunk=rechunk) # type: ignore[return-value] - + sources: list[str | Path | IO[bytes] | bytes | list[str] | list[Path]] = [] + if isinstance(source, list): + if len(source) > 0 and isinstance(source[0], (bytes, io.IOBase)): + sources = source # type: ignore[assignment] + else: + sources = [source] # type: ignore[list-item] + else: + sources = [source] -def _read_parquet_binary( - source: IO[bytes] | bytes, - *, - columns: Sequence[int] | Sequence[str] | None = None, - n_rows: int | None = None, - row_index_name: str | None = None, - row_index_offset: int = 0, - parallel: ParallelStrategy = "auto", - use_statistics: bool = True, - rechunk: bool = False, - low_memory: bool = False, -) -> DataFrame: - projection, columns = parse_columns_arg(columns) - row_index = parse_row_index_args(row_index_name, row_index_offset) + results: list[DataFrame] = [] + for source in sources: + with prepare_file_arg( + source, # type: ignore[arg-type] + use_pyarrow=True, + storage_options=storage_options, + ) as source_prep: + pa_table = pyarrow_parquet.read_table( + source_prep, + memory_map=memory_map, + columns=columns, + **pyarrow_options, + ) + result = from_arrow(pa_table, rechunk=rechunk) + results.append(result) # type: ignore[arg-type] - with prepare_file_arg(source) as source_prep: - pydf = PyDataFrame.read_parquet( - source_prep, - columns=columns, - projection=projection, - n_rows=n_rows, - row_index=row_index, - parallel=parallel, - use_statistics=use_statistics, - rechunk=rechunk, - low_memory=low_memory, - ) - return wrap_df(pydf) + if len(results) == 1: + return results[0] + else: + return plconcat(results) def read_parquet_schema(source: str | Path | IO[bytes] | bytes) -> dict[str, DataType]: @@ -295,7 +290,7 @@ def read_parquet_schema(source: str | Path | IO[bytes] | bytes) -> dict[str, Dat @deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4") @deprecate_renamed_parameter("row_count_offset", "row_index_offset", version="0.20.4") def scan_parquet( - source: str | Path | list[str] | list[Path], + source: str | Path | IO[bytes] | list[str] | list[Path] | list[IO[bytes]], *, n_rows: int | None = None, row_index_name: str | None = None, @@ -422,13 +417,13 @@ def scan_parquet( if isinstance(source, (str, Path)): source = normalize_filepath(source, check_not_directory=False) - else: + elif is_path_or_str_sequence(source): source = [ normalize_filepath(source, check_not_directory=False) for source in source ] return _scan_parquet_impl( - source, + source, # type: ignore[arg-type] n_rows=n_rows, cache=cache, parallel=parallel, @@ -448,7 +443,7 @@ def scan_parquet( def _scan_parquet_impl( - source: str | list[str] | list[Path], + source: str | list[str] | list[Path] | IO[str] | IO[bytes], *, n_rows: int | None = None, cache: bool = True, diff --git a/py-polars/polars/io/spreadsheet/functions.py b/py-polars/polars/io/spreadsheet/functions.py index 36e0ef9a462d..1fd25cc1417a 100644 --- a/py-polars/polars/io/spreadsheet/functions.py +++ b/py-polars/polars/io/spreadsheet/functions.py @@ -50,6 +50,7 @@ def read_excel( engine: ExcelSpreadsheetEngine = ..., 
engine_options: dict[str, Any] | None = ..., read_options: dict[str, Any] | None = ..., + has_header: bool = ..., columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., @@ -65,6 +66,7 @@ def read_excel( sheet_name: None = ..., engine: ExcelSpreadsheetEngine = ..., engine_options: dict[str, Any] | None = ..., + has_header: bool = ..., read_options: dict[str, Any] | None = ..., columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., @@ -82,6 +84,7 @@ def read_excel( engine: ExcelSpreadsheetEngine = ..., engine_options: dict[str, Any] | None = ..., read_options: dict[str, Any] | None = ..., + has_header: bool = ..., columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., @@ -100,6 +103,7 @@ def read_excel( engine: ExcelSpreadsheetEngine = ..., engine_options: dict[str, Any] | None = ..., read_options: dict[str, Any] | None = ..., + has_header: bool = ..., columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., @@ -116,6 +120,7 @@ def read_excel( engine: ExcelSpreadsheetEngine = ..., engine_options: dict[str, Any] | None = ..., read_options: dict[str, Any] | None = ..., + has_header: bool = ..., columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., @@ -132,6 +137,7 @@ def read_excel( engine: ExcelSpreadsheetEngine = ..., engine_options: dict[str, Any] | None = ..., read_options: dict[str, Any] | None = ..., + has_header: bool = ..., columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., @@ -149,6 +155,7 @@ def read_excel( engine: ExcelSpreadsheetEngine = "calamine", engine_options: dict[str, Any] | None = None, read_options: dict[str, Any] | None = None, + has_header: bool = True, columns: Sequence[int] | Sequence[str] | None = None, schema_overrides: SchemaDict | None = None, infer_schema_length: int | None = N_INFER_DEFAULT, @@ -207,6 +214,10 @@ def read_excel( * "calamine": `ExcelReader.load_sheet_by_name` * "xlsx2csv": `pl.read_csv` * "openpyxl": n/a (can only provide `engine_options`) + has_header + Indicate if the first row of the table data is a header or not. If False, + column names will be autogenerated in the following format: `column_x`, with + `x` being an enumeration over every column in the dataset, starting at 1. columns Columns to read from the sheet; if not specified, all columns are read. Can be given as a sequence of column names or indices. 
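For illustration, a minimal sketch of how the new `has_header` option to `read_excel` is intended to be used; the workbook path and sheet name below are hypothetical, and with `has_header=False` the column names are autogenerated as `column_1`, `column_2`, and so on, exactly as described in the parameter docs above:

import polars as pl

# Hypothetical workbook whose first row is data rather than a header.
df = pl.read_excel("report.xlsx", sheet_name="raw", has_header=False)

# Columns come back as column_1, column_2, ...; rename them as needed.
df = df.rename({"column_1": "id", "column_2": "amount"})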
@@ -285,6 +296,7 @@ def read_excel( schema_overrides=schema_overrides, infer_schema_length=infer_schema_length, raise_if_empty=raise_if_empty, + has_header=has_header, columns=columns, ) @@ -295,6 +307,7 @@ def read_ods( *, sheet_id: None = ..., sheet_name: str, + has_header: bool = ..., columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., @@ -308,6 +321,7 @@ def read_ods( *, sheet_id: None = ..., sheet_name: None = ..., + has_header: bool = ..., columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., @@ -321,6 +335,7 @@ def read_ods( *, sheet_id: int, sheet_name: str, + has_header: bool = ..., columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., @@ -334,6 +349,7 @@ def read_ods( *, sheet_id: Literal[0] | Sequence[int], sheet_name: None = ..., + has_header: bool = ..., columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., @@ -347,6 +363,7 @@ def read_ods( *, sheet_id: int, sheet_name: None = ..., + has_header: bool = ..., columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., @@ -360,6 +377,7 @@ def read_ods( *, sheet_id: None, sheet_name: list[str] | tuple[str], + has_header: bool = ..., columns: Sequence[int] | Sequence[str] | None = ..., schema_overrides: SchemaDict | None = ..., infer_schema_length: int | None = ..., @@ -372,6 +390,7 @@ def read_ods( *, sheet_id: int | Sequence[int] | None = None, sheet_name: str | list[str] | tuple[str] | None = None, + has_header: bool = True, columns: Sequence[int] | Sequence[str] | None = None, schema_overrides: SchemaDict | None = None, infer_schema_length: int | None = N_INFER_DEFAULT, @@ -396,6 +415,10 @@ def read_ods( sheet_name Sheet name(s) to convert; cannot be used in conjunction with `sheet_id`. If more than one is given then a `{sheetname:frame,}` dict is returned. + has_header + Indicate if the first row of the table data is a header or not. If False, + column names will be autogenerated in the following format: `column_x`, with + `x` being an enumeration over every column in the dataset, starting at 1. columns Columns to read from the sheet; if not specified, all columns are read. Can be given as a sequence of column names or indices. 
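A similar sketch for `read_ods` under the same assumptions (the file and sheet names are hypothetical); per the docstring, requesting more than one sheet returns a `{sheet_name: DataFrame}` dict, and `has_header=False` again autogenerates the column names:

import polars as pl

# Hypothetical ODS file; asking for two sheets returns a dict keyed by sheet name.
frames = pl.read_ods("ledger.ods", sheet_name=["prices", "volumes"], has_header=False)

# Each frame has autogenerated column names (column_1, column_2, ...).
prices = frames["prices"]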
@@ -446,6 +469,7 @@ def read_ods( schema_overrides=schema_overrides, infer_schema_length=infer_schema_length, raise_if_empty=raise_if_empty, + has_header=has_header, columns=columns, ) @@ -495,6 +519,7 @@ def _identify_workbook(wb: str | Path | IO[bytes] | bytes) -> str | None: def _read_spreadsheet( sheet_id: int | Sequence[int] | None, sheet_name: str | list[str] | tuple[str] | None, + *, source: str | Path | IO[bytes] | bytes, engine: ExcelSpreadsheetEngine, engine_options: dict[str, Any] | None = None, @@ -502,7 +527,7 @@ def _read_spreadsheet( schema_overrides: SchemaDict | None = None, infer_schema_length: int | None = N_INFER_DEFAULT, columns: Sequence[int] | Sequence[str] | None = None, - *, + has_header: bool = True, raise_if_empty: bool = True, ) -> pl.DataFrame | dict[str, pl.DataFrame]: if isinstance(source, (str, Path)): @@ -510,37 +535,16 @@ def _read_spreadsheet( if looks_like_url(source): source = process_file_url(source) - read_options = (read_options or {}).copy() + read_options = _get_read_options( + read_options, + engine=engine, + columns=columns, + has_header=has_header, + infer_schema_length=infer_schema_length, + ) engine_options = (engine_options or {}).copy() schema_overrides = dict(schema_overrides or {}) - # normalise some top-level parameters to 'read_options' entries - if engine == "calamine": - if ("use_columns" in read_options) and columns: - msg = 'cannot specify both `columns` and `read_options["use_columns"]`' - raise ParameterCollisionError(msg) - elif ("schema_sample_rows" in read_options) and ( - infer_schema_length != N_INFER_DEFAULT - ): - msg = 'cannot specify both `infer_schema_length` and `read_options["schema_sample_rows"]`' - raise ParameterCollisionError(msg) - - read_options["schema_sample_rows"] = infer_schema_length - - elif engine == "xlsx2csv": - if ("columns" in read_options) and columns: - msg = 'cannot specify both `columns` and `read_options["columns"]`' - raise ParameterCollisionError(msg) - elif ("infer_schema_length" in read_options) and ( - infer_schema_length != N_INFER_DEFAULT - ): - msg = 'cannot specify both `infer_schema_length` and `read_options["infer_schema_length"]`' - raise ParameterCollisionError(msg) - - read_options["infer_schema_length"] = infer_schema_length - else: - read_options["infer_schema_length"] = infer_schema_length - # establish the reading function, parser, and available worksheets reader_fn, parser, worksheets = _initialise_spreadsheet_parser( engine, source, engine_options @@ -573,6 +577,59 @@ def _read_spreadsheet( return next(iter(parsed_sheets.values())) +def _get_read_options( + read_options: dict[str, Any] | None, + *, + engine: ExcelSpreadsheetEngine, + columns: Sequence[int] | Sequence[str] | None, + infer_schema_length: int | None, + has_header: bool, +) -> dict[str, Any]: + """Normalise top-level parameters to engine-specific 'read_options' dict.""" + read_options = (read_options or {}).copy() + if engine == "calamine": + if ("use_columns" in read_options) and columns: + msg = 'cannot specify both `columns` and `read_options["use_columns"]`' + raise ParameterCollisionError(msg) + elif read_options.get("header_row") is not None and has_header is False: + msg = 'the values of `has_header` and `read_options["header_row"]` are not compatible' + raise ParameterCollisionError(msg) + elif ("schema_sample_rows" in read_options) and ( + infer_schema_length != N_INFER_DEFAULT + ): + msg = 'cannot specify both `infer_schema_length` and `read_options["schema_sample_rows"]`' + raise 
ParameterCollisionError(msg) + + read_options["schema_sample_rows"] = infer_schema_length + if has_header is False and "header_row" not in read_options: + read_options["header_row"] = None + + elif engine == "xlsx2csv": + if ("columns" in read_options) and columns: + msg = 'cannot specify both `columns` and `read_options["columns"]`' + raise ParameterCollisionError(msg) + elif ( + "has_header" in read_options + and read_options["has_header"] is not has_header + ): + msg = 'the values of `has_header` and `read_options["has_header"]` are not compatible' + raise ParameterCollisionError(msg) + elif ("infer_schema_length" in read_options) and ( + infer_schema_length != N_INFER_DEFAULT + ): + msg = 'cannot specify both `infer_schema_length` and `read_options["infer_schema_length"]`' + raise ParameterCollisionError(msg) + + read_options["infer_schema_length"] = infer_schema_length + if "has_header" not in read_options: + read_options["has_header"] = has_header + else: + read_options["infer_schema_length"] = infer_schema_length + read_options["has_header"] = has_header + + return read_options + + def _get_sheet_names( sheet_id: int | Sequence[int] | None, sheet_name: str | list[str] | tuple[str] | None, @@ -695,13 +752,7 @@ def _csv_buffer_to_frame( """Translate StringIO buffer containing delimited data as a DataFrame.""" # handle (completely) empty sheet data if csv.tell() == 0: - if raise_if_empty: - msg = ( - "empty Excel sheet" - "\n\nIf you want to read this as an empty DataFrame, set `raise_if_empty=False`." - ) - raise NoDataError(msg) - return pl.DataFrame() + return _empty_frame(raise_if_empty) if read_options is None: read_options = {} @@ -754,18 +805,21 @@ def _drop_null_data(df: pl.DataFrame, *, raise_if_empty: bool) -> pl.DataFrame: df = df.drop(*null_cols) if len(df) == 0 and len(df.columns) == 0: - if not raise_if_empty: - return df - else: - msg = ( - "empty Excel sheet" - "\n\nIf you want to read this as an empty DataFrame, set `raise_if_empty=False`." - ) - raise NoDataError(msg) + return _empty_frame(raise_if_empty) return df.filter(~F.all_horizontal(F.all().is_null())) +def _empty_frame(raise_if_empty: bool) -> pl.DataFrame: # noqa: FBT001 + if raise_if_empty: + msg = ( + "empty Excel sheet" + "\n\nIf you want to read this as an empty DataFrame, set `raise_if_empty=False`." 
+ ) + raise NoDataError(msg) + return pl.DataFrame() + + def _reorder_columns( df: pl.DataFrame, columns: Sequence[int] | Sequence[str] | None ) -> pl.DataFrame: @@ -788,6 +842,7 @@ def _read_spreadsheet_openpyxl( ) -> pl.DataFrame: """Use the 'openpyxl' library to read data from the given worksheet.""" infer_schema_length = read_options.pop("infer_schema_length", None) + has_header = read_options.pop("has_header", True) no_inference = infer_schema_length == 0 ws = parser[sheet_name] @@ -797,17 +852,28 @@ def _read_spreadsheet_openpyxl( if tables := getattr(ws, "tables", None): table = next(iter(tables.values())) rows = list(ws[table.ref]) - header.extend(cell.value for cell in rows.pop(0)) + if not rows: + return _empty_frame(raise_if_empty) + if has_header: + header.extend(cell.value for cell in rows.pop(0)) + else: + header.extend(f"column_{n}" for n in range(1, len(rows[0]) + 1)) if table.totalsRowCount: rows = rows[: -table.totalsRowCount] - rows_iter = iter(rows) + rows_iter = rows else: - rows_iter = ws.iter_rows() - for row in rows_iter: - row_values = [cell.value for cell in row] - if any(v is not None for v in row_values): - header.extend(row_values) - break + if not has_header: + if not (rows_iter := list(ws.iter_rows())): + return _empty_frame(raise_if_empty) + n_cols = len(rows_iter[0]) + header = [f"column_{n}" for n in range(1, n_cols + 1)] + else: + rows_iter = ws.iter_rows() + for row in rows_iter: + row_values = [cell.value for cell in row] + if any(v is not None for v in row_values): + header.extend(row_values) + break dtype = String if no_inference else None series_data = [] @@ -815,8 +881,8 @@ def _read_spreadsheet_openpyxl( if name: values = [cell.value for cell in column_data] if no_inference or (dtype := (schema_overrides or {}).get(name)) == String: # type: ignore[assignment] - # note: if we init series with mixed-type data (eg: str/int) - # the non-strings will become null, so we handle the cast here + # note: if we initialise the series with mixed-type data (eg: str/int) + # then the non-strings will become null, so we handle the cast here values = [str(v) if (v is not None) else v for v in values] s = pl.Series(name, values, dtype=dtype, strict=False) @@ -889,6 +955,10 @@ def _read_spreadsheet_calamine( else: ws_arrow = parser.load_sheet_eager(sheet_name, **read_options) df = from_arrow(ws_arrow) + if read_options.get("header_row", False) is None and not read_options.get( + "column_names" + ): + df.columns = [f"column_{i}" for i in range(1, len(df.columns) + 1)] # note: even if we applied parser dtypes we still re-apply schema_overrides # natively as we can refine integer/float types, temporal precision, etc. diff --git a/py-polars/polars/lazyframe/engine_config.py b/py-polars/polars/lazyframe/engine_config.py index 8dd75ebc48b6..ee6c2f8b7941 100644 --- a/py-polars/polars/lazyframe/engine_config.py +++ b/py-polars/polars/lazyframe/engine_config.py @@ -18,7 +18,7 @@ class GPUEngine: - `device`: Select the device to run the query on. - `memory_resource`: Set an RMM memory resource for - device-side allocations. + device-side allocations. 
""" device: int | None diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index fffa06b78369..44978528e272 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -2,6 +2,7 @@ import contextlib import os +import warnings from datetime import date, datetime, time, timedelta from functools import lru_cache, partial, reduce from io import BytesIO, StringIO @@ -41,6 +42,7 @@ _in_notebook, _is_generator, extend_bool, + find_stacklevel, is_bool_sequence, is_sequence, issue_warning, @@ -82,7 +84,7 @@ from polars.lazyframe.group_by import LazyGroupBy from polars.lazyframe.in_process import InProcessQuery from polars.schema import Schema -from polars.selectors import _expand_selectors, by_dtype, expand_selector +from polars.selectors import by_dtype, expand_selector with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import PyLazyFrame @@ -680,7 +682,7 @@ def serialize( The format in which to serialize. Options: - `"binary"`: Serialize to binary format (bytes). This is the default. - - `"json"`: Serialize to JSON format (string). + - `"json"`: Serialize to JSON format (string) (deprecated). See Also -------- @@ -716,6 +718,11 @@ def serialize( if format == "binary": serializer = self._ldf.serialize_binary elif format == "json": + msg = "'json' serialization format of LazyFrame is deprecated" + warnings.warn( + msg, + stacklevel=find_stacklevel(), + ) serializer = self._ldf.serialize_json else: msg = f"`format` must be one of {{'binary', 'json'}}, got {format!r}" @@ -2841,7 +2848,7 @@ def lazy(self) -> LazyFrame: Return lazy representation, i.e. itself. Useful for writing code that expects either a :class:`DataFrame` or - :class:`LazyFrame`. + :class:`LazyFrame`. On LazyFrame this is a no-op, and returns the same object. Returns ------- @@ -3816,7 +3823,7 @@ def group_by_dynamic( │ 2021-12-16 03:00:00 ┆ 6 │ └─────────────────────┴─────┘ - Group by windows of 1 hour starting at 2021-12-16 00:00:00. + Group by windows of 1 hour. >>> lf.group_by_dynamic("time", every="1h", closed="right").agg( ... pl.col("n") @@ -3993,7 +4000,7 @@ def join_asof( tolerance: str | int | float | timedelta | None = None, allow_parallel: bool = True, force_parallel: bool = False, - coalesce: bool | None = None, + coalesce: bool = True, ) -> LazyFrame: """ Perform an asof join. @@ -4071,53 +4078,214 @@ def join_asof( Force the physical plan to evaluate the computation of both DataFrames up to the join in parallel. coalesce - Coalescing behavior (merging of join columns). + Coalescing behavior (merging of `on` / `left_on` / `right_on` columns): - - None: -> join specific. - True: -> Always coalesce join columns. - False: -> Never coalesce join columns. Note that joining on any other expressions than `col` will turn off coalescing. - Examples -------- - >>> from datetime import datetime + >>> from datetime import date >>> gdp = pl.LazyFrame( ... { - ... "date": [ - ... datetime(2016, 1, 1), - ... datetime(2017, 1, 1), - ... datetime(2018, 1, 1), - ... datetime(2019, 1, 1), - ... ], # note record date: Jan 1st (sorted!) - ... "gdp": [4164, 4411, 4566, 4696], + ... "date": pl.date_range( + ... date(2016, 1, 1), + ... date(2020, 1, 1), + ... "1y", + ... eager=True, + ... ), + ... "gdp": [4164, 4411, 4566, 4696, 4827], ... } - ... ).set_sorted("date") + ... 
) + >>> gdp.collect() + shape: (5, 2) + ┌────────────┬──────┐ + │ date ┆ gdp │ + │ --- ┆ --- │ + │ date ┆ i64 │ + ╞════════════╪══════╡ + │ 2016-01-01 ┆ 4164 │ + │ 2017-01-01 ┆ 4411 │ + │ 2018-01-01 ┆ 4566 │ + │ 2019-01-01 ┆ 4696 │ + │ 2020-01-01 ┆ 4827 │ + └────────────┴──────┘ + >>> population = pl.LazyFrame( ... { - ... "date": [ - ... datetime(2016, 5, 12), - ... datetime(2017, 5, 12), - ... datetime(2018, 5, 12), - ... datetime(2019, 5, 12), - ... ], # note record date: May 12th (sorted!) - ... "population": [82.19, 82.66, 83.12, 83.52], + ... "date": [date(2016, 3, 1), date(2018, 8, 1), date(2019, 1, 1)], + ... "population": [82.19, 82.66, 83.12], ... } - ... ).set_sorted("date") + ... ).sort("date") + >>> population.collect() + shape: (3, 2) + ┌────────────┬────────────┐ + │ date ┆ population │ + │ --- ┆ --- │ + │ date ┆ f64 │ + ╞════════════╪════════════╡ + │ 2016-03-01 ┆ 82.19 │ + │ 2018-08-01 ┆ 82.66 │ + │ 2019-01-01 ┆ 83.12 │ + └────────────┴────────────┘ + + Note how the dates don't quite match. If we join them using `join_asof` and + `strategy='backward'`, then each date from `population` which doesn't have an + exact match is matched with the closest earlier date from `gdp`: + >>> population.join_asof(gdp, on="date", strategy="backward").collect() - shape: (4, 3) - ┌─────────────────────┬────────────┬──────┐ - │ date ┆ population ┆ gdp │ - │ --- ┆ --- ┆ --- │ - │ datetime[μs] ┆ f64 ┆ i64 │ - ╞═════════════════════╪════════════╪══════╡ - │ 2016-05-12 00:00:00 ┆ 82.19 ┆ 4164 │ - │ 2017-05-12 00:00:00 ┆ 82.66 ┆ 4411 │ - │ 2018-05-12 00:00:00 ┆ 83.12 ┆ 4566 │ - │ 2019-05-12 00:00:00 ┆ 83.52 ┆ 4696 │ - └─────────────────────┴────────────┴──────┘ + shape: (3, 3) + ┌────────────┬────────────┬──────┐ + │ date ┆ population ┆ gdp │ + │ --- ┆ --- ┆ --- │ + │ date ┆ f64 ┆ i64 │ + ╞════════════╪════════════╪══════╡ + │ 2016-03-01 ┆ 82.19 ┆ 4164 │ + │ 2018-08-01 ┆ 82.66 ┆ 4566 │ + │ 2019-01-01 ┆ 83.12 ┆ 4696 │ + └────────────┴────────────┴──────┘ + + Note how: + + - date `2016-03-01` from `population` is matched with `2016-01-01` from `gdp`; + - date `2018-08-01` from `population` is matched with `2018-01-01` from `gdp`. + + You can verify this by passing `coalesce=False`: + + >>> population.join_asof( + ... gdp, on="date", strategy="backward", coalesce=False + ... ).collect() + shape: (3, 4) + ┌────────────┬────────────┬────────────┬──────┐ + │ date ┆ population ┆ date_right ┆ gdp │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ date ┆ f64 ┆ date ┆ i64 │ + ╞════════════╪════════════╪════════════╪══════╡ + │ 2016-03-01 ┆ 82.19 ┆ 2016-01-01 ┆ 4164 │ + │ 2018-08-01 ┆ 82.66 ┆ 2018-01-01 ┆ 4566 │ + │ 2019-01-01 ┆ 83.12 ┆ 2019-01-01 ┆ 4696 │ + └────────────┴────────────┴────────────┴──────┘ + + If we instead use `strategy='forward'`, then each date from `population` which + doesn't have an exact match is matched with the closest later date from `gdp`: + + >>> population.join_asof(gdp, on="date", strategy="forward").collect() + shape: (3, 3) + ┌────────────┬────────────┬──────┐ + │ date ┆ population ┆ gdp │ + │ --- ┆ --- ┆ --- │ + │ date ┆ f64 ┆ i64 │ + ╞════════════╪════════════╪══════╡ + │ 2016-03-01 ┆ 82.19 ┆ 4411 │ + │ 2018-08-01 ┆ 82.66 ┆ 4696 │ + │ 2019-01-01 ┆ 83.12 ┆ 4696 │ + └────────────┴────────────┴──────┘ + + Note how: + + - date `2016-03-01` from `population` is matched with `2017-01-01` from `gdp`; + - date `2018-08-01` from `population` is matched with `2019-01-01` from `gdp`. 
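The backward and forward matches described above are easy to reproduce eagerly as a sanity check; a minimal standalone sketch using the same data (only `polars` is required):

from datetime import date

import polars as pl

gdp = pl.DataFrame(
    {
        "date": pl.date_range(date(2016, 1, 1), date(2020, 1, 1), "1y", eager=True),
        "gdp": [4164, 4411, 4566, 4696, 4827],
    }
)
population = pl.DataFrame(
    {
        "date": [date(2016, 3, 1), date(2018, 8, 1), date(2019, 1, 1)],
        "population": [82.19, 82.66, 83.12],
    }
).sort("date")

# "backward" takes the closest earlier gdp date, "forward" the closest later one.
print(population.join_asof(gdp, on="date", strategy="backward")["gdp"].to_list())
# [4164, 4566, 4696]
print(population.join_asof(gdp, on="date", strategy="forward")["gdp"].to_list())
# [4411, 4696, 4696]

The nearest strategy, covered next, simply picks whichever of the two candidate dates is closer in time.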
+ + Finally, `strategy='nearest'` gives us a mix of the two results above, as each + date from `population` which doesn't have an exact match is matched with the + closest date from `gdp`, regardless of whether it's earlier or later: + + >>> population.join_asof(gdp, on="date", strategy="nearest").collect() + shape: (3, 3) + ┌────────────┬────────────┬──────┐ + │ date ┆ population ┆ gdp │ + │ --- ┆ --- ┆ --- │ + │ date ┆ f64 ┆ i64 │ + ╞════════════╪════════════╪══════╡ + │ 2016-03-01 ┆ 82.19 ┆ 4164 │ + │ 2018-08-01 ┆ 82.66 ┆ 4696 │ + │ 2019-01-01 ┆ 83.12 ┆ 4696 │ + └────────────┴────────────┴──────┘ + + Note how: + + - date `2016-03-01` from `population` is matched with `2016-01-01` from `gdp`; + - date `2018-08-01` from `population` is matched with `2019-01-01` from `gdp`. + + They `by` argument allows joining on another column first, before the asof join. + In this example we join by `country` first, then asof join by date, as above. + + >>> gdp_dates = pl.date_range( # fmt: skip + ... date(2016, 1, 1), date(2020, 1, 1), "1y", eager=True + ... ) + >>> gdp2 = pl.LazyFrame( + ... { + ... "country": ["Germany"] * 5 + ["Netherlands"] * 5, + ... "date": pl.concat([gdp_dates, gdp_dates]), + ... "gdp": [4164, 4411, 4566, 4696, 4827, 784, 833, 914, 910, 909], + ... } + ... ).sort("country", "date") + >>> + >>> gdp2.collect() + shape: (10, 3) + ┌─────────────┬────────────┬──────┐ + │ country ┆ date ┆ gdp │ + │ --- ┆ --- ┆ --- │ + │ str ┆ date ┆ i64 │ + ╞═════════════╪════════════╪══════╡ + │ Germany ┆ 2016-01-01 ┆ 4164 │ + │ Germany ┆ 2017-01-01 ┆ 4411 │ + │ Germany ┆ 2018-01-01 ┆ 4566 │ + │ Germany ┆ 2019-01-01 ┆ 4696 │ + │ Germany ┆ 2020-01-01 ┆ 4827 │ + │ Netherlands ┆ 2016-01-01 ┆ 784 │ + │ Netherlands ┆ 2017-01-01 ┆ 833 │ + │ Netherlands ┆ 2018-01-01 ┆ 914 │ + │ Netherlands ┆ 2019-01-01 ┆ 910 │ + │ Netherlands ┆ 2020-01-01 ┆ 909 │ + └─────────────┴────────────┴──────┘ + >>> pop2 = pl.LazyFrame( + ... { + ... "country": ["Germany"] * 3 + ["Netherlands"] * 3, + ... "date": [ + ... date(2016, 3, 1), + ... date(2018, 8, 1), + ... date(2019, 1, 1), + ... date(2016, 3, 1), + ... date(2018, 8, 1), + ... date(2019, 1, 1), + ... ], + ... "population": [82.19, 82.66, 83.12, 17.11, 17.32, 17.40], + ... } + ... 
).sort("country", "date") + >>> + >>> pop2.collect() + shape: (6, 3) + ┌─────────────┬────────────┬────────────┐ + │ country ┆ date ┆ population │ + │ --- ┆ --- ┆ --- │ + │ str ┆ date ┆ f64 │ + ╞═════════════╪════════════╪════════════╡ + │ Germany ┆ 2016-03-01 ┆ 82.19 │ + │ Germany ┆ 2018-08-01 ┆ 82.66 │ + │ Germany ┆ 2019-01-01 ┆ 83.12 │ + │ Netherlands ┆ 2016-03-01 ┆ 17.11 │ + │ Netherlands ┆ 2018-08-01 ┆ 17.32 │ + │ Netherlands ┆ 2019-01-01 ┆ 17.4 │ + └─────────────┴────────────┴────────────┘ + >>> pop2.join_asof(gdp2, by="country", on="date", strategy="nearest").collect() + shape: (6, 4) + ┌─────────────┬────────────┬────────────┬──────┐ + │ country ┆ date ┆ population ┆ gdp │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ str ┆ date ┆ f64 ┆ i64 │ + ╞═════════════╪════════════╪════════════╪══════╡ + │ Germany ┆ 2016-03-01 ┆ 82.19 ┆ 4164 │ + │ Germany ┆ 2018-08-01 ┆ 82.66 ┆ 4696 │ + │ Germany ┆ 2019-01-01 ┆ 83.12 ┆ 4696 │ + │ Netherlands ┆ 2016-03-01 ┆ 17.11 ┆ 784 │ + │ Netherlands ┆ 2018-08-01 ┆ 17.32 ┆ 910 │ + │ Netherlands ┆ 2019-01-01 ┆ 17.4 ┆ 910 │ + └─────────────┴────────────┴────────────┴──────┘ + """ if not isinstance(other, LazyFrame): msg = f"expected `other` join table to be a LazyFrame, not a {type(other).__name__!r}" @@ -4393,6 +4561,94 @@ def join( ) ) + @unstable() + def join_where( + self, + other: LazyFrame, + *predicates: Expr | Iterable[Expr], + suffix: str = "_right", + ) -> LazyFrame: + """ + Perform a join based on one or multiple (in)equality predicates. + + A row from this table may be included in zero or multiple rows in the result, + and the relative order of rows may differ between the input and output tables. + + .. warning:: + This functionality is experimental. It may be + changed at any point without it being considered a breaking change. + + Parameters + ---------- + other + LazyFrame to join with. + *predicates + (In)Equality condition to join the two table on. + The left `pl.col(..)` will refer to the left table + and the right `pl.col(..)` + to the right table. + For example: `pl.col("time") >= pl.col("duration")` + suffix + Suffix to append to columns with a duplicate name. + + Notes + ----- + This method is strict about its equality expressions. + Only 1 equality expression is allowed per predicate, where + the lhs `pl.col` refers to the left table in the join, and the + rhs `pl.col` refers to the right table. + + Examples + -------- + >>> east = pl.LazyFrame( + ... { + ... "id": [100, 101, 102], + ... "dur": [120, 140, 160], + ... "rev": [12, 14, 16], + ... "cores": [2, 8, 4], + ... } + ... ) + >>> west = pl.LazyFrame( + ... { + ... "t_id": [404, 498, 676, 742], + ... "time": [90, 130, 150, 170], + ... "cost": [9, 13, 15, 16], + ... "cores": [4, 2, 1, 4], + ... } + ... ) + >>> east.join_where( + ... west, + ... pl.col("dur") < pl.col("time"), + ... pl.col("rev") < pl.col("cost"), + ... 
).collect() + shape: (5, 8) + ┌─────┬─────┬─────┬───────┬──────┬──────┬──────┬─────────────┐ + │ id ┆ dur ┆ rev ┆ cores ┆ t_id ┆ time ┆ cost ┆ cores_right │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ + ╞═════╪═════╪═════╪═══════╪══════╪══════╪══════╪═════════════╡ + │ 100 ┆ 120 ┆ 12 ┆ 2 ┆ 498 ┆ 130 ┆ 13 ┆ 2 │ + │ 100 ┆ 120 ┆ 12 ┆ 2 ┆ 676 ┆ 150 ┆ 15 ┆ 1 │ + │ 100 ┆ 120 ┆ 12 ┆ 2 ┆ 742 ┆ 170 ┆ 16 ┆ 4 │ + │ 101 ┆ 140 ┆ 14 ┆ 8 ┆ 676 ┆ 150 ┆ 15 ┆ 1 │ + │ 101 ┆ 140 ┆ 14 ┆ 8 ┆ 742 ┆ 170 ┆ 16 ┆ 4 │ + └─────┴─────┴─────┴───────┴──────┴──────┴──────┴─────────────┘ + + """ + if not isinstance(other, LazyFrame): + msg = f"expected `other` join table to be a LazyFrame, not a {type(other).__name__!r}" + raise TypeError(msg) + + pyexprs = parse_into_list_of_expressions(*predicates) + + return self._from_pyldf( + self._ldf.join_where( + other._ldf, + pyexprs, + suffix, + ) + ) + def with_columns( self, *exprs: IntoExpr | Iterable[IntoExpr], @@ -5783,9 +6039,7 @@ def explode( │ c ┆ 8 │ └─────────┴─────────┘ """ - columns = parse_into_list_of_expressions( - *_expand_selectors(self, columns, *more_columns) - ) + columns = parse_into_list_of_expressions(columns, *more_columns) return self._from_pyldf(self._ldf.explode(columns)) def unique( @@ -5874,7 +6128,7 @@ def unique( └─────┴─────┴─────┘ """ if subset is not None: - subset = _expand_selectors(self, subset) + subset = parse_into_list_of_expressions(subset) return self._from_pyldf(self._ldf.unique(maintain_order, subset, keep)) def drop_nulls( @@ -5971,7 +6225,7 @@ def drop_nulls( └──────┴─────┴──────┘ """ if subset is not None: - subset = _expand_selectors(self, subset) + subset = parse_into_list_of_expressions(subset) return self._from_pyldf(self._ldf.drop_nulls(subset)) def unpivot( @@ -6005,9 +6259,7 @@ def unpivot( value_name Name to give to the `value` column. Defaults to "value" streamable - Allow this node to run in the streaming engine. - If this runs in streaming, the output of the unpivot operation - will not have a stable ordering. 
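Stepping back to the `join_where` addition earlier in this file: the same predicate join can be exercised end to end with a couple of small frames. A brief, illustrative sketch reusing the docstring's east/west data:

import polars as pl

east = pl.LazyFrame(
    {
        "id": [100, 101, 102],
        "dur": [120, 140, 160],
        "rev": [12, 14, 16],
        "cores": [2, 8, 4],
    }
)
west = pl.LazyFrame(
    {
        "t_id": [404, 498, 676, 742],
        "time": [90, 130, 150, 170],
        "cost": [9, 13, 15, 16],
        "cores": [4, 2, 1, 4],
    }
)

# A row pair is emitted only when *both* inequality predicates hold.
pairs = east.join_where(
    west,
    pl.col("dur") < pl.col("time"),
    pl.col("rev") < pl.col("cost"),
).collect()

print(pairs.shape)  # (5, 8), matching the docstring output above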
+ deprecated Notes ----- @@ -6040,12 +6292,17 @@ def unpivot( │ z ┆ c ┆ 6 │ └─────┴──────────┴───────┘ """ - on = [] if on is None else _expand_selectors(self, on) - index = [] if index is None else _expand_selectors(self, index) + if not streamable: + issue_deprecation_warning( + "The `streamable` parameter for `LazyFrame.unpivot` is deprecated" + "This parameter has no effect", + version="1.5.0", + ) - return self._from_pyldf( - self._ldf.unpivot(on, index, value_name, variable_name, streamable) - ) + on = [] if on is None else parse_into_list_of_expressions(on) + index = [] if index is None else parse_into_list_of_expressions(index) + + return self._from_pyldf(self._ldf.unpivot(on, index, value_name, variable_name)) def map_batches( self, @@ -6224,7 +6481,7 @@ def unnest( │ bar ┆ 2 ┆ b ┆ null ┆ [3] ┆ womp │ └────────┴─────┴─────┴──────┴───────────┴───────┘ """ - columns = _expand_selectors(self, columns, *more_columns) + columns = parse_into_list_of_expressions(columns) return self._from_pyldf(self._ldf.unnest(columns)) def merge_sorted(self, other: LazyFrame, key: str) -> LazyFrame: diff --git a/py-polars/polars/meta/versions.py b/py-polars/polars/meta/versions.py index 6092fef10ecc..6788d25a68ea 100644 --- a/py-polars/polars/meta/versions.py +++ b/py-polars/polars/meta/versions.py @@ -20,13 +20,13 @@ def show_versions() -> None: Python: 3.11.8 (main, Feb 6 2024, 21:21:21) [Clang 15.0.0 (clang-1500.1.0.2.5)] ----Optional dependencies---- adbc_driver_manager: 0.11.0 + altair: 5.4.0 cloudpickle: 3.0.0 connectorx: 0.3.2 deltalake: 0.17.1 fastexcel: 0.10.4 fsspec: 2023.12.2 gevent: 24.2.1 - hvplot: 0.9.2 matplotlib: 3.8.4 nest_asyncio: 1.6.0 numpy: 1.26.4 @@ -44,9 +44,9 @@ def show_versions() -> None: # module) as a micro-optimization for polars' initial import import platform - deps = _get_dependency_info() + deps = _get_dependency_list() core_properties = ("Polars", "Index type", "Platform", "Python") - keylen = max(len(x) for x in [*core_properties, *deps.keys()]) + 1 + keylen = max(len(x) for x in [*core_properties, *deps]) + 1 print("--------Version info---------") print(f"{'Polars:':{keylen}s} {get_polars_version()}") @@ -55,14 +55,16 @@ def show_versions() -> None: print(f"{'Python:':{keylen}s} {sys.version}") print("\n----Optional dependencies----") - for name, v in deps.items(): - print(f"{name:{keylen}s} {v}") + for name in deps: + print(f"{name:{keylen}s} ", end="", flush=True) + print(_get_dependency_version(name)) -def _get_dependency_info() -> dict[str, str]: - # see the list of dependencies in pyproject.toml - opt_deps = [ +# See the list of dependencies in pyproject.toml. 
+def _get_dependency_list() -> list[str]: + return [ "adbc_driver_manager", + "altair", "cloudpickle", "connectorx", "deltalake", @@ -70,7 +72,6 @@ def _get_dependency_info() -> dict[str, str]: "fsspec", "gevent", "great_tables", - "hvplot", "matplotlib", "nest_asyncio", "numpy", @@ -84,7 +85,6 @@ def _get_dependency_info() -> dict[str, str]: "xlsx2csv", "xlsxwriter", ] - return {f"{name}:": _get_dependency_version(name) for name in opt_deps} def _get_dependency_version(dep_name: str) -> str: diff --git a/py-polars/polars/schema.py b/py-polars/polars/schema.py index 718ffec75b93..019d2d2f3ad0 100644 --- a/py-polars/polars/schema.py +++ b/py-polars/polars/schema.py @@ -1,9 +1,13 @@ from __future__ import annotations from collections import OrderedDict -from typing import TYPE_CHECKING, Iterable, Mapping +from collections.abc import Mapping +from typing import TYPE_CHECKING, Iterable + +from polars.datatypes._parse import parse_into_dtype if TYPE_CHECKING: + from polars._typing import PythonDataType from polars.datatypes import DataType BaseSchema = OrderedDict[str, DataType] @@ -49,10 +53,19 @@ class Schema(BaseSchema): def __init__( self, - schema: Mapping[str, DataType] | Iterable[tuple[str, DataType]] | None = None, + schema: ( + Mapping[str, DataType | PythonDataType] + | Iterable[tuple[str, DataType | PythonDataType]] + | None + ) = None, ): - schema = schema or {} - super().__init__(schema) + input = ( + schema.items() if schema and isinstance(schema, Mapping) else (schema or {}) + ) + super().__init__({name: parse_into_dtype(tp) for name, tp in input}) # type: ignore[misc] + + def __setitem__(self, name: str, dtype: DataType | PythonDataType) -> None: + super().__setitem__(name, parse_into_dtype(dtype)) # type: ignore[assignment] def names(self) -> list[str]: """Get the column names of the schema.""" @@ -65,3 +78,15 @@ def dtypes(self) -> list[DataType]: def len(self) -> int: """Get the number of columns in the schema.""" return len(self) + + def to_python(self) -> dict[str, type]: + """ + Return Schema as a dictionary of column names and their Python types. + + Examples + -------- + >>> s = pl.Schema({"x": pl.Int8(), "y": pl.String(), "z": pl.Duration("ms")}) + >>> s.to_python() + {'x': , 'y': , 'z': } + """ + return {name: tp.to_python() for name, tp in self.items()} diff --git a/py-polars/polars/selectors.py b/py-polars/polars/selectors.py index a42f346c6562..2e56aa3fb91e 100644 --- a/py-polars/polars/selectors.py +++ b/py-polars/polars/selectors.py @@ -2494,7 +2494,7 @@ def starts_with(*prefix: str) -> SelectorType: def string(*, include_categorical: bool = False) -> SelectorType: """ - Select all String (and, optionally, Categorical) string columns . + Select all String (and, optionally, Categorical) string columns. See Also -------- diff --git a/py-polars/polars/series/datetime.py b/py-polars/polars/series/datetime.py index 8c8bfb32bad8..928c81410a62 100644 --- a/py-polars/polars/series/datetime.py +++ b/py-polars/polars/series/datetime.py @@ -1764,10 +1764,12 @@ def round(self, every: str | dt.timedelta | IntoExprColumn) -> Series: This functionality is considered **unstable**. It may be changed at any point without it being considered a breaking change. - Each date/datetime in the first half of the interval is mapped to the start of - its bucket. - Each date/datetime in the second half of the interval is mapped to the end of - its bucket. + - Each date/datetime in the first half of the interval + is mapped to the start of its bucket. 
+ - Each date/datetime in the second half of the interval + is mapped to the end of its bucket. + - Half-way points are mapped to the start of their bucket. + Ambiguous results are localized using the DST offset of the original timestamp - for example, rounding `'2022-11-06 01:20:00 CST'` by `'1h'` results in `'2022-11-06 01:00:00 CST'`, whereas rounding `'2022-11-06 01:20:00 CDT'` by diff --git a/py-polars/polars/series/plotting.py b/py-polars/polars/series/plotting.py new file mode 100644 index 000000000000..cb5c6c93a1e1 --- /dev/null +++ b/py-polars/polars/series/plotting.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable + +from polars.dependencies import altair as alt + +if TYPE_CHECKING: + import sys + + from altair.typing import EncodeKwds + + if sys.version_info >= (3, 11): + from typing import Unpack + else: + from typing_extensions import Unpack + + from polars import Series + + +class SeriesPlot: + """Series.plot namespace.""" + + _accessor = "plot" + + def __init__(self, s: Series) -> None: + name = s.name or "value" + self._df = s.to_frame(name) + self._series_name = name + + def hist( + self, + /, + **kwargs: Unpack[EncodeKwds], + ) -> alt.Chart: + """ + Draw histogram. + + Polars does not implement plotting logic itself but instead defers to + `Altair `_. + + `s.plot.hist(**kwargs)` is shorthand for + `alt.Chart(s.to_frame()).mark_bar().encode(x=alt.X(f'{s.name}:Q', bin=True), y='count()', **kwargs).interactive()`, + and is provided for convenience - for full customisatibility, use a plotting + library directly. + + .. versionchanged:: 1.6.0 + In prior versions of Polars, HvPlot was the plotting backend. If you would + like to restore the previous plotting functionality, all you need to do + is add `import hvplot.polars` at the top of your script and replace + `df.plot` with `df.hvplot`. + + Parameters + ---------- + **kwargs + Additional arguments and keyword arguments passed to Altair. + + Examples + -------- + >>> s = pl.Series("price", [1, 3, 3, 3, 5, 2, 6, 5, 5, 5, 7]) + >>> s.plot.hist() # doctest: +SKIP + """ # noqa: W505 + if self._series_name == "count()": + msg = "Cannot use `plot.hist` when Series name is `'count()'`" + raise ValueError(msg) + return ( + alt.Chart(self._df) + .mark_bar() + .encode(x=alt.X(f"{self._series_name}:Q", bin=True), y="count()", **kwargs) # type: ignore[misc] + .interactive() + ) + + def kde( + self, + /, + **kwargs: Unpack[EncodeKwds], + ) -> alt.Chart: + """ + Draw kernel density estimate plot. + + Polars does not implement plotting logic itself but instead defers to + `Altair `_. + + `s.plot.kde(**kwargs)` is shorthand for + `alt.Chart(s.to_frame()).transform_density(s.name, as_=[s.name, 'density']).mark_area().encode(x=s.name, y='density:Q', **kwargs).interactive()`, + and is provided for convenience - for full customisatibility, use a plotting + library directly. + + .. versionchanged:: 1.6.0 + In prior versions of Polars, HvPlot was the plotting backend. If you would + like to restore the previous plotting functionality, all you need to do + is add `import hvplot.polars` at the top of your script and replace + `df.plot` with `df.hvplot`. + + Parameters + ---------- + **kwargs + Additional keyword arguments passed to Altair. 
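As the shorthand descriptions above spell out, these helpers are thin wrappers over Altair rather than a plotting engine of their own. A sketch of roughly what `s.plot.kde()` assembles, assuming `altair>=5.4.0` is installed:

import altair as alt
import polars as pl

s = pl.Series("price", [1, 3, 3, 3, 5, 2, 6, 5, 5, 5, 7])

# Approximately the chart built by the `kde` helper above.
chart = (
    alt.Chart(s.to_frame())
    .transform_density("price", as_=["price", "density"])
    .mark_area()
    .encode(x="price", y="density:Q")
    .interactive()
)
# `chart` renders directly in a notebook; `chart.save("kde.html")` writes a file.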
+ + Examples + -------- + >>> s = pl.Series("price", [1, 3, 3, 3, 5, 2, 6, 5, 5, 5, 7]) + >>> s.plot.kde() # doctest: +SKIP + """ # noqa: W505 + if self._series_name == "density": + msg = "Cannot use `plot.kde` when Series name is `'density'`" + raise ValueError(msg) + return ( + alt.Chart(self._df) + .transform_density(self._series_name, as_=[self._series_name, "density"]) + .mark_area() + .encode(x=self._series_name, y="density:Q", **kwargs) # type: ignore[misc] + .interactive() + ) + + def line( + self, + /, + **kwargs: Unpack[EncodeKwds], + ) -> alt.Chart: + """ + Draw line plot. + + Polars does not implement plotting logic itself but instead defers to + `Altair `_. + + `s.plot.line(**kwargs)` is shorthand for + `alt.Chart(s.to_frame().with_row_index()).mark_line().encode(x='index', y=s.name, **kwargs).interactive()`, + and is provided for convenience - for full customisatibility, use a plotting + library directly. + + .. versionchanged:: 1.6.0 + In prior versions of Polars, HvPlot was the plotting backend. If you would + like to restore the previous plotting functionality, all you need to do + is add `import hvplot.polars` at the top of your script and replace + `df.plot` with `df.hvplot`. + + Parameters + ---------- + **kwargs + Additional keyword arguments passed to Altair. + + Examples + -------- + >>> s = pl.Series("price", [1, 3, 3, 3, 5, 2, 6, 5, 5, 5, 7]) + >>> s.plot.kde() # doctest: +SKIP + """ # noqa: W505 + if self._series_name == "index": + msg = "Cannot call `plot.line` when Series name is 'index'" + raise ValueError(msg) + return ( + alt.Chart(self._df.with_row_index()) + .mark_line() + .encode(x="index", y=self._series_name, **kwargs) # type: ignore[misc] + .interactive() + ) + + def __getattr__(self, attr: str) -> Callable[..., alt.Chart]: + if self._series_name == "index": + msg = "Cannot call `plot.{attr}` when Series name is 'index'" + raise ValueError(msg) + if attr == "scatter": + # alias `scatter` to `point` because of how common it is + attr = "point" + method = getattr(alt.Chart(self._df.with_row_index()), f"mark_{attr}", None) + if method is None: + msg = "Altair has no method 'mark_{attr}'" + raise AttributeError(msg) + return ( + lambda **kwargs: method() + .encode(x="index", y=self._series_name, **kwargs) + .interactive() + ) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index a4c0d0b16045..df0f4ec469bf 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -86,12 +86,12 @@ ) from polars.datatypes._utils import dtype_to_init_repr from polars.dependencies import ( - _HVPLOT_AVAILABLE, + _ALTAIR_AVAILABLE, _PYARROW_AVAILABLE, _check_for_numpy, _check_for_pandas, _check_for_pyarrow, - hvplot, + altair, import_optional, ) from polars.dependencies import numpy as np @@ -104,6 +104,7 @@ from polars.series.categorical import CatNameSpace from polars.series.datetime import DateTimeNameSpace from polars.series.list import ListNameSpace +from polars.series.plotting import SeriesPlot from polars.series.string import StringNameSpace from polars.series.struct import StructNameSpace from polars.series.utils import expr_dispatch, get_ffi_func @@ -117,7 +118,6 @@ import jax import numpy.typing as npt import torch - from hvplot.plotting.core import hvPlotTabularPolars from polars import DataFrame, DataType, Expr from polars._typing import ( @@ -1250,7 +1250,31 @@ def __getitem__(self, key: MultiIndexSelector) -> Series: ... 
def __getitem__( self, key: SingleIndexSelector | MultiIndexSelector ) -> Any | Series: - """Get part of the Series as a new Series or scalar.""" + """ + Get part of the Series as a new Series or scalar. + + Parameters + ---------- + key + Row(s) to select. + + Returns + ------- + Series or scalar, depending on `key`. + + Examples + -------- + >>> s = pl.Series("a", [1, 4, 2]) + >>> s[0] + 1 + >>> s[0:2] + shape: (2,) + Series: 'a' [i64] + [ + 1 + 4 + ] + """ return get_series_item_by_key(self, key) def __setitem__( @@ -1359,7 +1383,10 @@ def __array_ufunc__( if isinstance(arg, (int, float, np.ndarray)): args.append(arg) elif isinstance(arg, Series): - args.append(arg.to_physical()._s.to_numpy_view()) + phys_arg = arg.to_physical() + if phys_arg._s.n_chunks() > 1: + phys_arg._s.rechunk(in_place=True) + args.append(phys_arg._s.to_numpy_view()) else: msg = f"unsupported type {type(arg).__name__!r} for {arg!r}" raise TypeError(msg) @@ -2422,7 +2449,7 @@ def hist( If None given, we determine the boundaries based on the data. bin_count If no bins provided, this will be used to determine - the distance of the bins + the distance of the bins. include_breakpoint Include a column that indicates the upper breakpoint. include_category @@ -2436,18 +2463,17 @@ def hist( -------- >>> a = pl.Series("a", [1, 3, 8, 8, 2, 1, 3]) >>> a.hist(bin_count=4) - shape: (5, 3) - ┌────────────┬─────────────┬───────┐ - │ breakpoint ┆ category ┆ count │ - │ --- ┆ --- ┆ --- │ - │ f64 ┆ cat ┆ u32 │ - ╞════════════╪═════════════╪═══════╡ - │ 0.0 ┆ (-inf, 0.0] ┆ 0 │ - │ 2.25 ┆ (0.0, 2.25] ┆ 3 │ - │ 4.5 ┆ (2.25, 4.5] ┆ 2 │ - │ 6.75 ┆ (4.5, 6.75] ┆ 0 │ - │ inf ┆ (6.75, inf] ┆ 2 │ - └────────────┴─────────────┴───────┘ + shape: (4, 3) + ┌────────────┬───────────────┬───────┐ + │ breakpoint ┆ category ┆ count │ + │ --- ┆ --- ┆ --- │ + │ f64 ┆ cat ┆ u32 │ + ╞════════════╪═══════════════╪═══════╡ + │ 2.75 ┆ (0.993, 2.75] ┆ 3 │ + │ 4.5 ┆ (2.75, 4.5] ┆ 2 │ + │ 6.25 ┆ (4.5, 6.25] ┆ 0 │ + │ 8.0 ┆ (6.25, 8.0] ┆ 2 │ + └────────────┴───────────────┴───────┘ """ out = ( self.to_frame() @@ -2557,7 +2583,7 @@ def unique_counts(self) -> Series: ] """ - def entropy(self, base: float = math.e, *, normalize: bool = False) -> float | None: + def entropy(self, base: float = math.e, *, normalize: bool = True) -> float | None: """ Computes the entropy. @@ -3946,7 +3972,7 @@ def equals( See Also -------- - assert_series_equal + polars.testing.assert_series_equal Examples -------- @@ -4986,27 +5012,28 @@ def mode(self) -> Series: def sign(self) -> Series: """ - Compute the element-wise indication of the sign. + Compute the element-wise sign function on numeric types. - The returned values can be -1, 0, or 1: + The returned value is computed as follows: - * -1 if x < 0. - * 0 if x == 0. - * 1 if x > 0. + * -1 if x < 0. + * 1 if x > 0. + * x otherwise (typically 0, but could be NaN if the input is). - (null values are preserved as-is). + Null values are preserved as-is, and the dtype of the input is preserved. 
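Earlier in this file the `hist` docstring was rewritten because the binning itself changed: with `bin_count=4` the series is now split into four finite bins over the observed range instead of five buckets padded with ±inf. A quick numeric check using the docstring's data:

import polars as pl

a = pl.Series("a", [1, 3, 8, 8, 2, 1, 3])
out = a.hist(bin_count=4)

# Four finite bins covering the data, no (-inf, ...] or (..., inf] buckets.
assert out.height == 4
print(out["breakpoint"].to_list())  # [2.75, 4.5, 6.25, 8.0]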
Examples -------- - >>> s = pl.Series("a", [-9.0, -0.0, 0.0, 4.0, None]) + >>> s = pl.Series("a", [-9.0, -0.0, 0.0, 4.0, float("nan"), None]) >>> s.sign() - shape: (5,) - Series: 'a' [i64] + shape: (6,) + Series: 'a' [f64] [ - -1 - 0 - 0 - 1 + -1.0 + -0.0 + 0.0 + 1.0 + NaN null ] """ @@ -5327,7 +5354,9 @@ def map_elements( warn_on_inefficient_map(function, columns=[self.name], map_target="series") return self._from_pyseries( - self._s.apply_lambda(function, pl_return_dtype, skip_nulls) + self._s.map_elements( + function, return_dtype=pl_return_dtype, skip_nulls=skip_nulls + ) ) def shift(self, n: int = 1, *, fill_value: IntoExpr | None = None) -> Series: @@ -6874,7 +6903,7 @@ def ewm_mean( ignore_nulls: bool = False, ) -> Series: r""" - Exponentially-weighted moving average. + Compute exponentially-weighted moving average. Parameters ---------- @@ -6889,11 +6918,11 @@ def ewm_mean( .. math:: \alpha = \frac{2}{\theta + 1} \; \forall \; \theta \geq 1 half_life - Specify decay in terms of half-life, :math:`\lambda`, with + Specify decay in terms of half-life, :math:`\tau`, with .. math:: - \alpha = 1 - \exp \left\{ \frac{ -\ln(2) }{ \lambda } \right\} \; - \forall \; \lambda > 0 + \alpha = 1 - \exp \left\{ \frac{ -\ln(2) }{ \tau } \right\} \; + \forall \; \tau > 0 alpha Specify smoothing factor alpha directly, :math:`0 < \alpha \leq 1`. adjust @@ -6949,20 +6978,21 @@ def ewm_mean_by( half_life: str | timedelta, ) -> Series: r""" - Calculate time-based exponentially weighted moving average. + Compute time-based exponentially weighted moving average. - Given observations :math:`x_1, x_2, \ldots, x_n` at times - :math:`t_1, t_2, \ldots, t_n`, the EWMA is calculated as + Given observations :math:`x_0, x_1, \ldots, x_{n-1}` at times + :math:`t_0, t_1, \ldots, t_{n-1}`, the EWMA is calculated as .. math:: y_0 &= x_0 - \alpha_i &= \exp(-\lambda(t_i - t_{i-1})) + \alpha_i &= 1 - \exp \left\{ \frac{ -\ln(2)(t_i-t_{i-1}) } + { \tau } \right\} y_i &= \alpha_i x_i + (1 - \alpha_i) y_{i-1}; \quad i > 0 - where :math:`\lambda` equals :math:`\ln(2) / \text{half_life}`. + where :math:`\tau` is the `half_life`. Parameters ---------- @@ -7038,7 +7068,7 @@ def ewm_std( ignore_nulls: bool = False, ) -> Series: r""" - Exponentially-weighted moving standard deviation. + Compute exponentially-weighted moving standard deviation. Parameters ---------- @@ -7122,7 +7152,7 @@ def ewm_var( ignore_nulls: bool = False, ) -> Series: r""" - Exponentially-weighted moving variance. + Compute exponentially-weighted moving variance. Parameters ---------- @@ -7377,7 +7407,7 @@ def struct(self) -> StructNameSpace: @property @unstable() - def plot(self) -> hvPlotTabularPolars: + def plot(self) -> SeriesPlot: """ Create a plot namespace. @@ -7385,33 +7415,44 @@ def plot(self) -> hvPlotTabularPolars: This functionality is currently considered **unstable**. It may be changed at any point without it being considered a breaking change. + .. versionchanged:: 1.6.0 + In prior versions of Polars, HvPlot was the plotting backend. If you would + like to restore the previous plotting functionality, all you need to do + is add `import hvplot.polars` at the top of your script and replace + `df.plot` with `df.hvplot`. + Polars does not implement plotting logic itself, but instead defers to - hvplot. Please see the `hvplot reference gallery `_ - for more information and documentation. 
+ Altair: + + - `s.plot.hist(**kwargs)` + is shorthand for + `alt.Chart(s.to_frame()).mark_bar().encode(x=alt.X(f'{s.name}:Q', bin=True), y='count()', **kwargs).interactive()` + - `s.plot.kde(**kwargs)` + is shorthand for + `alt.Chart(s.to_frame()).transform_density(s.name, as_=[s.name, 'density']).mark_area().encode(x=s.name, y='density:Q', **kwargs).interactive()` + - for any other attribute `attr`, `s.plot.attr(**kwargs)` + is shorthand for + `alt.Chart(s.to_frame().with_row_index()).mark_attr().encode(x='index', y=s.name, **kwargs).interactive()` Examples -------- Histogram: - >>> s = pl.Series("values", [1, 4, 2]) + >>> s = pl.Series([1, 4, 4, 6, 2, 4, 3, 5, 5, 7, 1]) >>> s.plot.hist() # doctest: +SKIP - KDE plot (note: in addition to ``hvplot``, this one also requires ``scipy``): + KDE plot: >>> s.plot.kde() # doctest: +SKIP - For more info on what you can pass, you can use ``hvplot.help``: + Line plot: - >>> import hvplot # doctest: +SKIP - >>> hvplot.help("hist") # doctest: +SKIP - """ - if not _HVPLOT_AVAILABLE or parse_version(hvplot.__version__) < parse_version( - "0.9.1" - ): - msg = "hvplot>=0.9.1 is required for `.plot`" + >>> s.plot.line() # doctest: +SKIP + """ # noqa: W505 + if not _ALTAIR_AVAILABLE or parse_version(altair.__version__) < (5, 4, 0): + msg = "altair>=5.4.0 is required for `.plot`" raise ModuleUpgradeRequiredError(msg) - hvplot.post_patch() - return hvplot.plotting.core.hvPlotTabularPolars(self) + return SeriesPlot(self) def _resolve_temporal_dtype( diff --git a/py-polars/polars/series/string.py b/py-polars/polars/series/string.py index 306188e36b92..953803fc1253 100644 --- a/py-polars/polars/series/string.py +++ b/py-polars/polars/series/string.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Mapping from polars._utils.deprecation import deprecate_function from polars._utils.unstable import unstable +from polars._utils.various import no_default from polars.datatypes.constants import N_INFER_DEFAULT from polars.series.utils import expr_dispatch @@ -18,6 +19,7 @@ TimeUnit, TransferEncoding, ) + from polars._utils.various import NoDefault from polars.polars import PySeries @@ -376,7 +378,7 @@ def contains( self, pattern: str | Expr, *, literal: bool = False, strict: bool = True ) -> Series: """ - Check if strings in Series contain a substring that matches a regex. + Check if the string contains a substring that matches a pattern. Parameters ---------- @@ -480,7 +482,7 @@ def find( See Also -------- - contains : Check if string contains a substring that matches a regex. + contains : Check if the string contains a substring that matches a pattern. Examples -------- @@ -535,7 +537,7 @@ def ends_with(self, suffix: str | Expr) -> Series: See Also -------- - contains : Check if string contains a substring that matches a regex. + contains : Check if the string contains a substring that matches a pattern. starts_with : Check if string values start with a substring. Examples @@ -562,7 +564,7 @@ def starts_with(self, prefix: str | Expr) -> Series: See Also -------- - contains : Check if string contains a substring that matches a regex. + contains : Check if the string contains a substring that matches a pattern. ends_with : Check if string values end with a substring. Examples @@ -1481,7 +1483,7 @@ def zfill(self, length: int | IntoExprColumn) -> Series: def to_lowercase(self) -> Series: """ - Modify the strings to their lowercase equivalent. + Modify strings to their lowercase equivalent. 
Examples -------- @@ -1497,7 +1499,7 @@ def to_lowercase(self) -> Series: def to_uppercase(self) -> Series: """ - Modify the strings to their uppercase equivalent. + Modify strings to their uppercase equivalent. Examples -------- @@ -1513,17 +1515,31 @@ def to_uppercase(self) -> Series: def to_titlecase(self) -> Series: """ - Modify the strings to their titlecase equivalent. + Modify strings to their titlecase equivalent. + + Notes + ----- + This is a form of case transform where the first letter of each word is + capitalized, with the rest of the word in lowercase. Non-alphanumeric + characters define the word boundaries. Examples -------- - >>> s = pl.Series("sing", ["welcome to my world", "THERE'S NO TURNING BACK"]) + >>> s = pl.Series( + ... "quotes", + ... [ + ... "'e.t. phone home'", + ... "you talkin' to me?", + ... "to infinity,and BEYOND!", + ... ], + ... ) >>> s.str.to_titlecase() - shape: (2,) - Series: 'sing' [str] + shape: (3,) + Series: 'quotes' [str] [ - "Welcome To My World" - "There's No Turning Back" + "'E.T. Phone Home'" + "You Talkin' To Me?" + "To Infinity,And Beyond!" ] """ @@ -1804,9 +1820,9 @@ def contains_any( self, patterns: Series | list[str], *, ascii_case_insensitive: bool = False ) -> Series: """ - Use the aho-corasick algorithm to find matches. + Use the Aho-Corasick algorithm to find matches. - This version determines if any of the patterns find a match. + Determines if any of the patterns are contained in the string. Parameters ---------- @@ -1817,6 +1833,11 @@ def contains_any( When this option is enabled, searching will be performed without respect to case for ASCII letters (a-z and A-Z) only. + Notes + ----- + This method supports matching on string literals only, and does not support + regular expression matching. + Examples -------- >>> _ = pl.Config.set_fmt_str_lengths(100) @@ -1840,28 +1861,39 @@ def contains_any( def replace_many( self, - patterns: Series | list[str], - replace_with: Series | list[str] | str, + patterns: Series | list[str] | Mapping[str, str], + replace_with: Series | list[str] | str | NoDefault = no_default, *, ascii_case_insensitive: bool = False, ) -> Series: """ - Use the aho-corasick algorithm to replace many matches. + Use the Aho-Corasick algorithm to replace many matches. Parameters ---------- patterns String patterns to search and replace. + Also accepts a mapping of patterns to their replacement as syntactic sugar + for `replace_many(pl.Series(mapping.keys()), pl.Series(mapping.values()))`. replace_with Strings to replace where a pattern was a match. - This can be broadcast, so it supports many:one and many:many. + Length must match the length of `patterns` or have length 1. This can be + broadcasted, so it supports many:one and many:many. ascii_case_insensitive Enable ASCII-aware case-insensitive matching. When this option is enabled, searching will be performed without respect to case for ASCII letters (a-z and A-Z) only. + Notes + ----- + This method supports matching on string literals only, and does not support + regular expression matching. + Examples -------- + Replace many patterns by passing lists of equal length to the `patterns` and + `replace_with` parameters. + >>> _ = pl.Config.set_fmt_str_lengths(100) >>> s = pl.Series( ... "lyrics", @@ -1879,6 +1911,49 @@ def replace_many( "Tell you what me want, what me really really want" "Can me feel the love tonight" ] + + Broadcast a replacement for many patterns by passing a string or a sequence of + length 1 to the `replace_with` parameter. 
+ + >>> _ = pl.Config.set_fmt_str_lengths(100) + >>> s = pl.Series( + ... "lyrics", + ... [ + ... "Everybody wants to rule the world", + ... "Tell me what you want, what you really really want", + ... "Can you feel the love tonight", + ... ], + ... ) + >>> s.str.replace_many(["me", "you", "they"], "") + shape: (3,) + Series: 'lyrics' [str] + [ + "Everybody wants to rule the world" + "Tell what want, what really really want" + "Can feel the love tonight" + ] + + Passing a mapping with patterns and replacements is also supported as syntactic + sugar. + + >>> _ = pl.Config.set_fmt_str_lengths(100) + >>> s = pl.Series( + ... "lyrics", + ... [ + ... "Everybody wants to rule the world", + ... "Tell me what you want, what you really really want", + ... "Can you feel the love tonight", + ... ], + ... ) + >>> mapping = {"me": "you", "you": "me", "want": "need"} + >>> s.str.replace_many(mapping) + shape: (3,) + Series: 'lyrics' [str] + [ + "Everybody needs to rule the world" + "Tell you what me need, what me really really need" + "Can me feel the love tonight" + ] """ @unstable() @@ -1890,7 +1965,7 @@ def extract_many( overlapping: bool = False, ) -> Series: """ - Use the aho-corasick algorithm to extract many matches. + Use the Aho-Corasick algorithm to extract many matches. Parameters ---------- @@ -1903,6 +1978,11 @@ def extract_many( overlapping Whether matches may overlap. + Notes + ----- + This method supports matching on string literals only, and does not support + regular expression matching. + Examples -------- >>> s = pl.Series("values", ["discontent"]) diff --git a/py-polars/polars/testing/asserts/frame.py b/py-polars/polars/testing/asserts/frame.py index 5d6112b6cb08..ca7348140fc9 100644 --- a/py-polars/polars/testing/asserts/frame.py +++ b/py-polars/polars/testing/asserts/frame.py @@ -70,18 +70,12 @@ def assert_frame_equal( >>> from polars.testing import assert_frame_equal >>> df1 = pl.DataFrame({"a": [1, 2, 3]}) >>> df2 = pl.DataFrame({"a": [1, 5, 3]}) - >>> assert_frame_equal(df1, df2) # doctest: +SKIP + >>> assert_frame_equal(df1, df2) Traceback (most recent call last): ... - AssertionError: Series are different (value mismatch) + AssertionError: DataFrames are different (value mismatch for column 'a') [left]: [1, 2, 3] [right]: [1, 5, 3] - - The above exception was the direct cause of the following exception: - - Traceback (most recent call last): - ... - AssertionError: values for column 'a' are different """ __tracebackhide__ = True @@ -250,13 +244,14 @@ def assert_frame_not_equal( >>> from polars.testing import assert_frame_not_equal >>> df1 = pl.DataFrame({"a": [1, 2, 3]}) >>> df2 = pl.DataFrame({"a": [1, 2, 3]}) - >>> assert_frame_not_equal(df1, df2) # doctest: +SKIP + >>> assert_frame_not_equal(df1, df2) Traceback (most recent call last): ... 
- AssertionError: frames are equal + AssertionError: DataFrames are equal (but are expected not to be) """ __tracebackhide__ = True + lazy = _assert_correct_input_type(left, right) try: assert_frame_equal( left=left, @@ -272,5 +267,6 @@ def assert_frame_not_equal( except AssertionError: return else: - msg = "frames are equal" + objects = "LazyFrames" if lazy else "DataFrames" + msg = f"{objects} are equal (but are expected not to be)" raise AssertionError(msg) diff --git a/py-polars/polars/testing/asserts/series.py b/py-polars/polars/testing/asserts/series.py index ad316f565aad..e85e90790bd4 100644 --- a/py-polars/polars/testing/asserts/series.py +++ b/py-polars/polars/testing/asserts/series.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from polars._utils.deprecation import deprecate_renamed_parameter from polars.datatypes import ( @@ -20,6 +20,19 @@ from polars import DataType +def _assert_correct_input_type(left: Any, right: Any) -> bool: + __tracebackhide__ = True + + if not (isinstance(left, Series) and isinstance(right, Series)): + raise_assertion_error( + "inputs", + "unexpected input types", + type(left).__name__, + type(right).__name__, + ) + return True + + @deprecate_renamed_parameter("check_dtype", "check_dtypes", version="0.20.31") def assert_series_equal( left: Series, @@ -81,22 +94,16 @@ def assert_series_equal( >>> from polars.testing import assert_series_equal >>> s1 = pl.Series([1, 2, 3]) >>> s2 = pl.Series([1, 5, 3]) - >>> assert_series_equal(s1, s2) # doctest: +SKIP + >>> assert_series_equal(s1, s2) Traceback (most recent call last): ... - AssertionError: Series are different (value mismatch) + AssertionError: Series are different (exact value mismatch) [left]: [1, 2, 3] [right]: [1, 5, 3] """ __tracebackhide__ = True - if not (isinstance(left, Series) and isinstance(right, Series)): # type: ignore[redundant-expr] - raise_assertion_error( - "inputs", - "unexpected input types", - type(left).__name__, - type(right).__name__, - ) + _assert_correct_input_type(left, right) if left.len() != right.len(): raise_assertion_error("Series", "length mismatch", left.len(), right.len()) @@ -397,13 +404,14 @@ def assert_series_not_equal( >>> from polars.testing import assert_series_not_equal >>> s1 = pl.Series([1, 2, 3]) >>> s2 = pl.Series([1, 2, 3]) - >>> assert_series_not_equal(s1, s2) # doctest: +SKIP + >>> assert_series_not_equal(s1, s2) Traceback (most recent call last): ... 
- AssertionError: Series are equal + AssertionError: Series are equal (but are expected not to be) """ __tracebackhide__ = True + _assert_correct_input_type(left, right) try: assert_series_equal( left=left, @@ -419,5 +427,5 @@ def assert_series_not_equal( except AssertionError: return else: - msg = "Series are equal" + msg = "Series are equal (but are expected not to be)" raise AssertionError(msg) diff --git a/py-polars/pyproject.toml b/py-polars/pyproject.toml index 913f96768b2d..efa5bd1f15f9 100644 --- a/py-polars/pyproject.toml +++ b/py-polars/pyproject.toml @@ -71,7 +71,7 @@ iceberg = ["pyiceberg >= 0.5.0"] async = ["gevent"] cloudpickle = ["cloudpickle"] graph = ["matplotlib"] -plot = ["hvplot >= 0.9.1", "polars[pandas]"] +plot = ["altair >= 5.4.0"] style = ["great-tables >= 0.8.0"] timezone = ["backports.zoneinfo; python_version < '3.9'", "tzdata; platform_system == 'Windows'"] @@ -103,6 +103,7 @@ module = [ "IPython.*", "adbc_driver_manager.*", "adbc_driver_sqlite.*", + "altair.*", "arrow_odbc", "backports", "connectorx", @@ -110,7 +111,6 @@ module = [ "fsspec.*", "gevent", "great_tables", - "hvplot.*", "jax.*", "kuzu", "matplotlib.*", @@ -237,6 +237,7 @@ markers = [ "release: Tests that should be run on a Polars release build.", "slow: Tests with a longer than average runtime.", "write_disk: Tests that write to disk", + "may_fail_auto_streaming: Test that may fail when automatically using the streaming engine for all lazy queries.", ] filterwarnings = [ # Fail on warnings diff --git a/py-polars/requirements-dev.txt b/py-polars/requirements-dev.txt index 50192c9653a1..9d0a2df292cb 100644 --- a/py-polars/requirements-dev.txt +++ b/py-polars/requirements-dev.txt @@ -17,7 +17,7 @@ pip # Interop numpy -numba; python_version < '3.13' # Numba can lag Python releases +numba >= 0.54; python_version < '3.13' # Numba can lag Python releases pandas pyarrow pydantic>=2.0.0 @@ -47,7 +47,7 @@ deltalake>=0.15.0 # Csv zstandard # Plotting -hvplot>=0.9.1 +altair>=5.4.0 # Styling great-tables>=0.8.0; python_version >= '3.9' # Async diff --git a/py-polars/requirements-lint.txt b/py-polars/requirements-lint.txt index 4119d9a23fb3..28c5e645e0d3 100644 --- a/py-polars/requirements-lint.txt +++ b/py-polars/requirements-lint.txt @@ -1,3 +1,3 @@ mypy==1.11.1 -ruff==0.5.0 -typos==1.23.5 +ruff==0.6.3 +typos==1.24.2 diff --git a/py-polars/src/lazyframe/visitor/mod.rs b/py-polars/src/lazyframe/visitor/mod.rs deleted file mode 100644 index 674049b9bb42..000000000000 --- a/py-polars/src/lazyframe/visitor/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub(crate) mod expr_nodes; -pub(crate) mod nodes; diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs index 9647a0ac33ce..484583f85150 100644 --- a/py-polars/src/lib.rs +++ b/py-polars/src/lib.rs @@ -1,4 +1,3 @@ -#![feature(vec_into_raw_parts)] #![allow(clippy::nonstandard_macro_braces)] // Needed because clippy does not understand proc macro of PyO3 #![allow(clippy::transmute_undefined_repr)] #![allow(non_local_definitions)] @@ -11,55 +10,29 @@ mod build { } mod allocator; -#[cfg(feature = "csv")] -mod batched_csv; -#[cfg(feature = "polars_cloud")] -mod cloud; -mod conversion; -mod dataframe; -mod datatypes; -mod error; -mod exceptions; -mod expr; -mod file; -mod functions; -mod gil_once_cell; -mod interop; -mod lazyframe; -mod lazygroupby; -mod map; #[cfg(debug_assertions)] mod memory; -#[cfg(feature = "object")] -mod object; -#[cfg(feature = "object")] -mod on_startup; -mod prelude; -mod py_modules; -mod series; -#[cfg(feature = "sql")] -mod sql; -mod utils; - -use 
pyo3::prelude::*; -use pyo3::{wrap_pyfunction, wrap_pymodule}; -use crate::allocator::create_allocator_capsule; +use allocator::create_allocator_capsule; #[cfg(feature = "csv")] -use crate::batched_csv::PyBatchedCsv; -use crate::conversion::Wrap; -use crate::dataframe::PyDataFrame; -use crate::expr::PyExpr; -use crate::functions::PyStringCacheHolder; -use crate::lazyframe::{PyInProcessQuery, PyLazyFrame}; -use crate::lazygroupby::PyLazyGroupBy; -use crate::series::PySeries; +use polars_python::batched_csv::PyBatchedCsv; +#[cfg(feature = "polars_cloud")] +use polars_python::cloud; +use polars_python::dataframe::PyDataFrame; +use polars_python::expr::PyExpr; +use polars_python::functions::PyStringCacheHolder; +use polars_python::lazyframe::{PyInProcessQuery, PyLazyFrame}; +use polars_python::lazygroupby::PyLazyGroupBy; +use polars_python::series::PySeries; #[cfg(feature = "sql")] -use crate::sql::PySQLContext; +use polars_python::sql::PySQLContext; +use polars_python::{exceptions, functions}; +use pyo3::prelude::*; +use pyo3::{wrap_pyfunction, wrap_pymodule}; #[pymodule] fn _ir_nodes(_py: Python, m: &Bound) -> PyResult<()> { - use crate::lazyframe::visitor::nodes::*; + use polars_python::lazyframe::visitor::nodes::*; m.add_class::().unwrap(); m.add_class::().unwrap(); m.add_class::().unwrap(); @@ -84,8 +57,8 @@ fn _ir_nodes(_py: Python, m: &Bound) -> PyResult<()> { #[pymodule] fn _expr_nodes(_py: Python, m: &Bound) -> PyResult<()> { - use crate::lazyframe::visitor::expr_nodes::*; - use crate::lazyframe::PyExprIR; + use polars_python::lazyframe::visit::PyExprIR; + use polars_python::lazyframe::visitor::expr_nodes::*; // Expressions m.add_class::().unwrap(); m.add_class::().unwrap(); @@ -134,6 +107,7 @@ fn polars(py: Python, m: &Bound) -> PyResult<()> { m.add_wrapped(wrap_pymodule!(_ir_nodes))?; // Expr objects m.add_wrapped(wrap_pymodule!(_expr_nodes))?; + // Functions - eager m.add_wrapped(wrap_pyfunction!(functions::concat_df)) .unwrap(); @@ -300,7 +274,7 @@ fn polars(py: Python, m: &Bound) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(functions::dtype_str_repr)) .unwrap(); #[cfg(feature = "object")] - m.add_wrapped(wrap_pyfunction!(on_startup::__register_startup_deps)) + m.add_wrapped(wrap_pyfunction!(functions::__register_startup_deps)) .unwrap(); // Functions - random diff --git a/py-polars/tests/benchmark/data/__init__.py b/py-polars/tests/benchmark/data/__init__.py index b7f246f37abc..255752458b72 100644 --- a/py-polars/tests/benchmark/data/__init__.py +++ b/py-polars/tests/benchmark/data/__init__.py @@ -1,6 +1,6 @@ """Data generation functionality for use in the benchmarking suite.""" from tests.benchmark.data.h2oai import generate_group_by_data -from tests.benchmark.data.tpch import load_tpch_table +from tests.benchmark.data.pdsh import load_pdsh_table -__all__ = ["load_tpch_table", "generate_group_by_data"] +__all__ = ["load_pdsh_table", "generate_group_by_data"] diff --git a/py-polars/tests/benchmark/data/pdsh/__init__.py b/py-polars/tests/benchmark/data/pdsh/__init__.py new file mode 100644 index 000000000000..ef007f5ed8d9 --- /dev/null +++ b/py-polars/tests/benchmark/data/pdsh/__init__.py @@ -0,0 +1,5 @@ +"""Generate data for the PDS-H benchmark tests.""" + +from tests.benchmark.data.pdsh.generate_data import load_pdsh_table + +__all__ = ["load_pdsh_table"] diff --git a/py-polars/tests/benchmark/data/tpch/dbgen/dbgen b/py-polars/tests/benchmark/data/pdsh/dbgen/dbgen similarity index 100% rename from py-polars/tests/benchmark/data/tpch/dbgen/dbgen rename to 
py-polars/tests/benchmark/data/pdsh/dbgen/dbgen diff --git a/py-polars/tests/benchmark/data/tpch/dbgen/dists.dss b/py-polars/tests/benchmark/data/pdsh/dbgen/dists.dss similarity index 100% rename from py-polars/tests/benchmark/data/tpch/dbgen/dists.dss rename to py-polars/tests/benchmark/data/pdsh/dbgen/dists.dss diff --git a/py-polars/tests/benchmark/data/tpch/generate_data.py b/py-polars/tests/benchmark/data/pdsh/generate_data.py similarity index 73% rename from py-polars/tests/benchmark/data/tpch/generate_data.py rename to py-polars/tests/benchmark/data/pdsh/generate_data.py index 3b4d81be51ff..0f8866311ada 100644 --- a/py-polars/tests/benchmark/data/tpch/generate_data.py +++ b/py-polars/tests/benchmark/data/pdsh/generate_data.py @@ -1,8 +1,19 @@ """ -Script to generate data for running the TPC-H benchmark. - -Data generation logic was adapted from the TPC-H benchmark tools: -https://www.tpc.org/tpch/ +Disclaimer. + +Certain portions of the contents of this file are derived from TPC-H version 3.0.1 +(retrieved from +http://www.tpc.org/tpc_documents_current_versions/current_specifications5.asp). +Such portions are subject to copyrights held by Transaction Processing +Performance Council (“TPC”) and licensed under the TPC EULA is available at +http://www.tpc.org/tpc_documents_current_versions/current_specifications5.asp) +(the “TPC EULA”). + +You may not use this file except in compliance with the TPC EULA. +DISCLAIMER: Portions of this file is derived from the TPC-H benchmark and as +such any result obtained using this file are not comparable to published TPC-H +Benchmark results, as the results obtained from using this file do not comply with +the TPC-H Benchmark. """ from __future__ import annotations @@ -19,12 +30,12 @@ CURRENT_DIR = Path(__file__).parent DBGEN_DIR = CURRENT_DIR / "dbgen" -__all__ = ["load_tpch_table"] +__all__ = ["load_pdsh_table"] -def load_tpch_table(table_name: str, scale_factor: float = 0.01) -> pl.DataFrame: +def load_pdsh_table(table_name: str, scale_factor: float = 0.01) -> pl.DataFrame: """ - Load a TPC-H table from disk. + Load a PDS-H table from disk. If the file does not exist, it is generated along with all other tables. """ @@ -32,16 +43,16 @@ def load_tpch_table(table_name: str, scale_factor: float = 0.01) -> pl.DataFrame file_path = folder / f"{table_name}.parquet" if not file_path.exists(): - _generate_tpch_data(scale_factor) + _generate_pdsh_data(scale_factor) return pl.read_parquet(file_path) -def _generate_tpch_data(scale_factor: float = 0.01) -> None: - """Generate all TPC-H datasets with the given scale factor.""" +def _generate_pdsh_data(scale_factor: float = 0.01) -> None: + """Generate all PDS-H datasets with the given scale factor.""" # TODO: Can we make this work under Windows? 
if sys.platform == "win32": - msg = "cannot generate TPC-H data under Windows" + msg = "cannot generate PDS-H data under Windows" raise RuntimeError(msg) subprocess.run(["./dbgen", "-f", "-v", "-s", str(scale_factor)], cwd=DBGEN_DIR) diff --git a/py-polars/tests/benchmark/data/tpch/__init__.py b/py-polars/tests/benchmark/data/tpch/__init__.py deleted file mode 100644 index 2973049f3fbd..000000000000 --- a/py-polars/tests/benchmark/data/tpch/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Generate data for the TPC-H benchmark tests.""" - -from tests.benchmark.data.tpch.generate_data import load_tpch_table - -__all__ = ["load_tpch_table"] diff --git a/py-polars/tests/benchmark/interop/test_numpy.py b/py-polars/tests/benchmark/interop/test_numpy.py index e44b40134558..5b516b0ecb19 100644 --- a/py-polars/tests/benchmark/interop/test_numpy.py +++ b/py-polars/tests/benchmark/interop/test_numpy.py @@ -18,19 +18,19 @@ def floats_array() -> np.ndarray[Any, Any]: return np.random.randn(n_rows) -@pytest.fixture() +@pytest.fixture def floats(floats_array: np.ndarray[Any, Any]) -> pl.Series: return pl.Series(floats_array) -@pytest.fixture() +@pytest.fixture def floats_with_nulls(floats: pl.Series) -> pl.Series: null_probability = 0.1 validity = pl.Series(np.random.uniform(size=floats.len())) > null_probability return pl.select(pl.when(validity).then(floats)).to_series() -@pytest.fixture() +@pytest.fixture def floats_chunked(floats_array: np.ndarray[Any, Any]) -> pl.Series: n_chunks = 5 chunk_len = len(floats_array) // n_chunks diff --git a/py-polars/tests/benchmark/test_join_where.py b/py-polars/tests/benchmark/test_join_where.py new file mode 100644 index 000000000000..d0bfd7d15b6d --- /dev/null +++ b/py-polars/tests/benchmark/test_join_where.py @@ -0,0 +1,73 @@ +"""Benchmark tests for join_where with inequality conditions.""" + +from __future__ import annotations + +import numpy as np +import pytest + +import polars as pl + +pytestmark = pytest.mark.benchmark() + + +def test_strict_inequalities(east_west: tuple[pl.DataFrame, pl.DataFrame]) -> None: + east, west = east_west + result = ( + east.lazy() + .join_where( + west.lazy(), + [pl.col("dur") < pl.col("time"), pl.col("rev") > pl.col("cost")], + ) + .collect() + ) + + assert len(result) > 0 + + +def test_non_strict_inequalities(east_west: tuple[pl.DataFrame, pl.DataFrame]) -> None: + east, west = east_west + result = ( + east.lazy() + .join_where( + west.lazy(), + [pl.col("dur") <= pl.col("time"), pl.col("rev") >= pl.col("cost")], + ) + .collect() + ) + + assert len(result) > 0 + + +@pytest.fixture(scope="module") +def east_west() -> tuple[pl.DataFrame, pl.DataFrame]: + num_rows_left, num_rows_right = 50_000, 5_000 + rng = np.random.default_rng(42) + + # Generate two separate datasets where revenue/cost are linearly related to + # duration/time, but add some noise to the west table so that there are some + # rows where the cost for the same or greater time will be less than the east table. 
+ east_dur = rng.integers(1_000, 50_000, num_rows_left) + east_rev = (east_dur * 0.123).astype(np.int32) + west_time = rng.integers(1_000, 50_000, num_rows_right) + west_cost = west_time * 0.123 + west_cost += rng.normal(0.0, 1.0, num_rows_right) + west_cost = west_cost.astype(np.int32) + + east = pl.DataFrame( + { + "id": np.arange(0, num_rows_left), + "dur": east_dur, + "rev": east_rev, + "cores": rng.integers(1, 10, num_rows_left), + } + ) + west = pl.DataFrame( + { + "t_id": np.arange(0, num_rows_right), + "time": west_time, + "cost": west_cost, + "cores": rng.integers(1, 10, num_rows_right), + } + ) + + return east, west diff --git a/py-polars/tests/benchmark/test_tpch.py b/py-polars/tests/benchmark/test_pdsh.py similarity index 91% rename from py-polars/tests/benchmark/test_tpch.py rename to py-polars/tests/benchmark/test_pdsh.py index bd09d92161c9..2ee601b5b895 100644 --- a/py-polars/tests/benchmark/test_tpch.py +++ b/py-polars/tests/benchmark/test_pdsh.py @@ -1,58 +1,76 @@ +""" +Disclaimer. + +Certain portions of the contents of this file are derived from TPC-H version 3.0.1 +(retrieved from +http://www.tpc.org/tpc_documents_current_versions/current_specifications5.asp). +Such portions are subject to copyrights held by Transaction Processing +Performance Council (“TPC”) and licensed under the TPC EULA is available at +http://www.tpc.org/tpc_documents_current_versions/current_specifications5.asp) +(the “TPC EULA”). + +You may not use this file except in compliance with the TPC EULA. +DISCLAIMER: Portions of this file is derived from the TPC-H benchmark and as +such any result obtained using this file are not comparable to published TPC-H +Benchmark results, as the results obtained from using this file do not comply with +the TPC-H Benchmark. 
+""" + import sys from datetime import date import pytest import polars as pl -from tests.benchmark.data import load_tpch_table +from tests.benchmark.data import load_pdsh_table if sys.platform == "win32": - pytest.skip("TPC-H data cannot be generated under Windows", allow_module_level=True) + pytest.skip("PDS-H data cannot be generated under Windows", allow_module_level=True) pytestmark = pytest.mark.benchmark() @pytest.fixture(scope="module") def customer() -> pl.LazyFrame: - return load_tpch_table("customer").lazy() + return load_pdsh_table("customer").lazy() @pytest.fixture(scope="module") def lineitem() -> pl.LazyFrame: - return load_tpch_table("lineitem").lazy() + return load_pdsh_table("lineitem").lazy() @pytest.fixture(scope="module") def nation() -> pl.LazyFrame: - return load_tpch_table("nation").lazy() + return load_pdsh_table("nation").lazy() @pytest.fixture(scope="module") def orders() -> pl.LazyFrame: - return load_tpch_table("orders").lazy() + return load_pdsh_table("orders").lazy() @pytest.fixture(scope="module") def part() -> pl.LazyFrame: - return load_tpch_table("part").lazy() + return load_pdsh_table("part").lazy() @pytest.fixture(scope="module") def partsupp() -> pl.LazyFrame: - return load_tpch_table("partsupp").lazy() + return load_pdsh_table("partsupp").lazy() @pytest.fixture(scope="module") def region() -> pl.LazyFrame: - return load_tpch_table("region").lazy() + return load_pdsh_table("region").lazy() @pytest.fixture(scope="module") def supplier() -> pl.LazyFrame: - return load_tpch_table("supplier").lazy() + return load_pdsh_table("supplier").lazy() -def test_tpch_q1(lineitem: pl.LazyFrame) -> None: +def test_pdsh_q1(lineitem: pl.LazyFrame) -> None: var1 = date(1998, 9, 2) q_final = ( @@ -81,7 +99,7 @@ def test_tpch_q1(lineitem: pl.LazyFrame) -> None: q_final.collect() -def test_tpch_q2( +def test_pdsh_q2( nation: pl.LazyFrame, part: pl.LazyFrame, partsupp: pl.LazyFrame, @@ -125,7 +143,7 @@ def test_tpch_q2( q_final.collect() -def test_tpch_q3( +def test_pdsh_q3( customer: pl.LazyFrame, lineitem: pl.LazyFrame, orders: pl.LazyFrame ) -> None: var1 = "BUILDING" @@ -154,7 +172,7 @@ def test_tpch_q3( q_final.collect() -def test_tpch_q4(lineitem: pl.LazyFrame, orders: pl.LazyFrame) -> None: +def test_pdsh_q4(lineitem: pl.LazyFrame, orders: pl.LazyFrame) -> None: var1 = date(1993, 7, 1) var2 = date(1993, 10, 1) @@ -170,7 +188,7 @@ def test_tpch_q4(lineitem: pl.LazyFrame, orders: pl.LazyFrame) -> None: q_final.collect() -def test_tpch_q5( +def test_pdsh_q5( customer: pl.LazyFrame, lineitem: pl.LazyFrame, nation: pl.LazyFrame, @@ -205,7 +223,7 @@ def test_tpch_q5( q_final.collect() -def test_tpch_q6(lineitem: pl.LazyFrame) -> None: +def test_pdsh_q6(lineitem: pl.LazyFrame) -> None: var1 = date(1994, 1, 1) var2 = date(1995, 1, 1) var3 = 0.05 @@ -225,7 +243,7 @@ def test_tpch_q6(lineitem: pl.LazyFrame) -> None: q_final.collect() -def test_tpch_q7( +def test_pdsh_q7( customer: pl.LazyFrame, lineitem: pl.LazyFrame, nation: pl.LazyFrame, @@ -274,7 +292,7 @@ def test_tpch_q7( q_final.collect() -def test_tpch_q8( +def test_pdsh_q8( customer: pl.LazyFrame, lineitem: pl.LazyFrame, nation: pl.LazyFrame, @@ -322,7 +340,7 @@ def test_tpch_q8( q_final.collect() -def test_tpch_q9( +def test_pdsh_q9( lineitem: pl.LazyFrame, nation: pl.LazyFrame, orders: pl.LazyFrame, @@ -357,7 +375,7 @@ def test_tpch_q9( q_final.collect() -def test_tpch_q10( +def test_pdsh_q10( customer: pl.LazyFrame, lineitem: pl.LazyFrame, nation: pl.LazyFrame, @@ -404,7 +422,7 @@ def test_tpch_q10( q_final.collect() 
-def test_tpch_q11( +def test_pdsh_q11( nation: pl.LazyFrame, partsupp: pl.LazyFrame, supplier: pl.LazyFrame ) -> None: var1 = "GERMANY" @@ -438,7 +456,7 @@ def test_tpch_q11( q_final.collect() -def test_tpch_q12(lineitem: pl.LazyFrame, orders: pl.LazyFrame) -> None: +def test_pdsh_q12(lineitem: pl.LazyFrame, orders: pl.LazyFrame) -> None: var1 = "MAIL" var2 = "SHIP" var3 = date(1994, 1, 1) @@ -467,7 +485,7 @@ def test_tpch_q12(lineitem: pl.LazyFrame, orders: pl.LazyFrame) -> None: q_final.collect() -def test_tpch_q13(customer: pl.LazyFrame, orders: pl.LazyFrame) -> None: +def test_pdsh_q13(customer: pl.LazyFrame, orders: pl.LazyFrame) -> None: var1 = "special" var2 = "requests" @@ -484,7 +502,7 @@ def test_tpch_q13(customer: pl.LazyFrame, orders: pl.LazyFrame) -> None: q_final.collect() -def test_tpch_q14(lineitem: pl.LazyFrame, part: pl.LazyFrame) -> None: +def test_pdsh_q14(lineitem: pl.LazyFrame, part: pl.LazyFrame) -> None: var1 = date(1995, 9, 1) var2 = date(1995, 10, 1) @@ -507,7 +525,7 @@ def test_tpch_q14(lineitem: pl.LazyFrame, part: pl.LazyFrame) -> None: q_final.collect() -def test_tpch_q15(lineitem: pl.LazyFrame, supplier: pl.LazyFrame) -> None: +def test_pdsh_q15(lineitem: pl.LazyFrame, supplier: pl.LazyFrame) -> None: var1 = date(1996, 1, 1) var2 = date(1996, 4, 1) @@ -532,7 +550,7 @@ def test_tpch_q15(lineitem: pl.LazyFrame, supplier: pl.LazyFrame) -> None: q_final.collect() -def test_tpch_q16( +def test_pdsh_q16( part: pl.LazyFrame, partsupp: pl.LazyFrame, supplier: pl.LazyFrame ) -> None: var1 = "Brand#45" @@ -558,7 +576,7 @@ def test_tpch_q16( q_final.collect() -def test_tpch_q17(lineitem: pl.LazyFrame, part: pl.LazyFrame) -> None: +def test_pdsh_q17(lineitem: pl.LazyFrame, part: pl.LazyFrame) -> None: var1 = "Brand#23" var2 = "MED BOX" @@ -579,7 +597,7 @@ def test_tpch_q17(lineitem: pl.LazyFrame, part: pl.LazyFrame) -> None: q_final.collect() -def test_tpch_q18( +def test_pdsh_q18( customer: pl.LazyFrame, lineitem: pl.LazyFrame, orders: pl.LazyFrame ) -> None: var1 = 300 @@ -608,7 +626,7 @@ def test_tpch_q18( q_final.collect() -def test_tpch_q19(lineitem: pl.LazyFrame, part: pl.LazyFrame) -> None: +def test_pdsh_q19(lineitem: pl.LazyFrame, part: pl.LazyFrame) -> None: q_final = ( part.join(lineitem, left_on="p_partkey", right_on="l_partkey") .filter(pl.col("l_shipmode").is_in(["AIR", "AIR REG"])) @@ -649,7 +667,7 @@ def test_tpch_q19(lineitem: pl.LazyFrame, part: pl.LazyFrame) -> None: q_final.collect() -def test_tpch_q20( +def test_pdsh_q20( lineitem: pl.LazyFrame, nation: pl.LazyFrame, part: pl.LazyFrame, @@ -687,7 +705,7 @@ def test_tpch_q20( q_final.collect() -def test_tpch_q21( +def test_pdsh_q21( lineitem: pl.LazyFrame, nation: pl.LazyFrame, orders: pl.LazyFrame, @@ -723,7 +741,7 @@ def test_tpch_q21( q_final.collect() -def test_tpch_q22(customer: pl.LazyFrame, orders: pl.LazyFrame) -> None: +def test_pdsh_q22(customer: pl.LazyFrame, orders: pl.LazyFrame) -> None: q1 = ( customer.with_columns(pl.col("c_phone").str.slice(0, 2).alias("cntrycode")) .filter(pl.col("cntrycode").str.contains("13|31|23|29|30|18|17")) diff --git a/py-polars/tests/docs/test_user_guide.py b/py-polars/tests/docs/test_user_guide.py index 08be6fe9dfbf..89203b4d09eb 100644 --- a/py-polars/tests/docs/test_user_guide.py +++ b/py-polars/tests/docs/test_user_guide.py @@ -29,7 +29,7 @@ def _change_test_dir() -> Iterator[None]: os.chdir(current_path) -@pytest.mark.docs() +@pytest.mark.docs @pytest.mark.parametrize("path", snippet_paths) @pytest.mark.usefixtures("_change_test_dir") def 
test_run_python_snippets(path: Path) -> None: diff --git a/py-polars/tests/unit/cloud/test_prepare_cloud_plan.py b/py-polars/tests/unit/cloud/test_prepare_cloud_plan.py index 825c2c130e57..d99bab04ef7d 100644 --- a/py-polars/tests/unit/cloud/test_prepare_cloud_plan.py +++ b/py-polars/tests/unit/cloud/test_prepare_cloud_plan.py @@ -6,10 +6,9 @@ import polars as pl from polars._utils.cloud import prepare_cloud_plan -from polars.exceptions import InvalidOperationError +from polars.exceptions import ComputeError, InvalidOperationError CLOUD_SOURCE = "s3://my-nonexistent-bucket/dataset" -CLOUD_SINK = "s3://my-nonexistent-bucket/result" @pytest.mark.parametrize( @@ -22,44 +21,13 @@ ], ) def test_prepare_cloud_plan(lf: pl.LazyFrame) -> None: - result = prepare_cloud_plan(lf, CLOUD_SINK) + result = prepare_cloud_plan(lf) assert isinstance(result, bytes) deserialized = pl.LazyFrame.deserialize(BytesIO(result)) assert isinstance(deserialized, pl.LazyFrame) -def test_prepare_cloud_plan_sink_added() -> None: - lf = pl.LazyFrame({"a": [1, 2], "b": [3, 4]}) - - result = prepare_cloud_plan(lf, CLOUD_SINK) - - deserialized = pl.LazyFrame.deserialize(BytesIO(result)) - assert "SINK (cloud)" in deserialized.explain() - - -def test_prepare_cloud_plan_invalid_sink_uri() -> None: - lf = pl.LazyFrame({"a": [1, 2], "b": [3, 4]}) - local_path = "~/local/result.parquet" - - with pytest.raises(InvalidOperationError, match="non-cloud paths not supported"): - prepare_cloud_plan(lf, local_path) - - -def test_prepare_cloud_plan_optimization_toggle() -> None: - lf = pl.LazyFrame({"a": [1, 2], "b": [3, 4]}) - - with pytest.raises(TypeError, match="unexpected keyword argument"): - prepare_cloud_plan(lf, CLOUD_SINK, nonexistent_optimization=False) - - result = prepare_cloud_plan(lf, CLOUD_SINK, projection_pushdown=False) - assert isinstance(result, bytes) - - # TODO: How to check that this optimization was toggled correctly? 
- deserialized = pl.LazyFrame.deserialize(BytesIO(result)) - assert isinstance(deserialized, pl.LazyFrame) - - @pytest.mark.parametrize( "lf", [ @@ -69,12 +37,6 @@ def test_prepare_cloud_plan_optimization_toggle() -> None: pl.LazyFrame({"a": [1, 2], "b": [3, 4]}).select( pl.col("b").map_batches(lambda x: sum(x)) ), - pl.LazyFrame({"a": [{"x": 1, "y": 2}]}).select( - pl.col("a").name.map(lambda x: x.upper()) - ), - pl.LazyFrame({"a": [{"x": 1, "y": 2}]}).select( - pl.col("a").name.map_fields(lambda x: x.upper()) - ), pl.LazyFrame({"a": [1, 2], "b": [3, 4]}).map_batches(lambda x: x), pl.LazyFrame({"a": [1, 2], "b": [3, 4]}) .group_by("a") @@ -85,14 +47,31 @@ def test_prepare_cloud_plan_optimization_toggle() -> None: pl.scan_parquet(CLOUD_SOURCE).filter( pl.col("a") < pl.lit(1).map_elements(lambda x: x + 1) ), + pl.LazyFrame({"a": [[1, 2], [3, 4, 5]], "b": [3, 4]}).select( + pl.col("a").map_elements(lambda x: sum(x), return_dtype=pl.Int64) + ), ], ) -def test_prepare_cloud_plan_fail_on_udf(lf: pl.LazyFrame) -> None: - with pytest.raises( - InvalidOperationError, - match="logical plan ineligible for execution on Polars Cloud", - ): - prepare_cloud_plan(lf, CLOUD_SINK) +def test_prepare_cloud_plan_udf(lf: pl.LazyFrame) -> None: + result = prepare_cloud_plan(lf) + assert isinstance(result, bytes) + + deserialized = pl.LazyFrame.deserialize(BytesIO(result)) + assert isinstance(deserialized, pl.LazyFrame) + + +def test_prepare_cloud_plan_optimization_toggle() -> None: + lf = pl.LazyFrame({"a": [1, 2], "b": [3, 4]}) + + with pytest.raises(TypeError, match="unexpected keyword argument"): + prepare_cloud_plan(lf, nonexistent_optimization=False) + + result = prepare_cloud_plan(lf, projection_pushdown=False) + assert isinstance(result, bytes) + + # TODO: How to check that this optimization was toggled correctly? 
+ deserialized = pl.LazyFrame.deserialize(BytesIO(result)) + assert isinstance(deserialized, pl.LazyFrame) @pytest.mark.parametrize( @@ -109,10 +88,10 @@ def test_prepare_cloud_plan_fail_on_local_data_source(lf: pl.LazyFrame) -> None: InvalidOperationError, match="logical plan ineligible for execution on Polars Cloud", ): - prepare_cloud_plan(lf, CLOUD_SINK) + prepare_cloud_plan(lf) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_prepare_cloud_plan_fail_on_python_scan(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) data_path = tmp_path / "data.parquet" @@ -124,4 +103,20 @@ def test_prepare_cloud_plan_fail_on_python_scan(tmp_path: Path) -> None: InvalidOperationError, match="logical plan ineligible for execution on Polars Cloud", ): - prepare_cloud_plan(lf, CLOUD_SINK) + prepare_cloud_plan(lf) + + +@pytest.mark.parametrize( + "lf", + [ + pl.LazyFrame({"a": [{"x": 1, "y": 2}]}).select( + pl.col("a").name.map(lambda x: x.upper()) + ), + pl.LazyFrame({"a": [{"x": 1, "y": 2}]}).select( + pl.col("a").name.map_fields(lambda x: x.upper()) + ), + ], +) +def test_prepare_cloud_plan_fail_on_serialization(lf: pl.LazyFrame) -> None: + with pytest.raises(ComputeError, match="serialization not supported"): + prepare_cloud_plan(lf) diff --git a/py-polars/tests/unit/conftest.py b/py-polars/tests/unit/conftest.py index 828ad87684cb..7335f8f46835 100644 --- a/py-polars/tests/unit/conftest.py +++ b/py-polars/tests/unit/conftest.py @@ -32,13 +32,13 @@ NESTED_DTYPES = [pl.List, pl.Struct, pl.Array] -@pytest.fixture() +@pytest.fixture def partition_limit() -> int: """The limit at which Polars will start partitioning in debug builds.""" return 15 -@pytest.fixture() +@pytest.fixture def df() -> pl.DataFrame: df = pl.DataFrame( { @@ -68,14 +68,14 @@ def df() -> pl.DataFrame: ) -@pytest.fixture() +@pytest.fixture def df_no_lists(df: pl.DataFrame) -> pl.DataFrame: return df.select( pl.all().exclude(["list_str", "list_int", "list_bool", "list_int", "list_flt"]) ) -@pytest.fixture() +@pytest.fixture def fruits_cars() -> pl.DataFrame: return pl.DataFrame( { @@ -88,7 +88,7 @@ def fruits_cars() -> pl.DataFrame: ) -@pytest.fixture() +@pytest.fixture def str_ints_df() -> pl.DataFrame: n = 1000 @@ -199,7 +199,7 @@ def get_peak(self) -> int: return tracemalloc.get_traced_memory()[1] -@pytest.fixture() +@pytest.fixture def memory_usage_without_pyarrow() -> Generator[MemoryUsage, Any, Any]: """ Provide an API for measuring peak memory usage. 
diff --git a/py-polars/tests/unit/constructors/test_constructors.py b/py-polars/tests/unit/constructors/test_constructors.py index ffda370aa538..fda072930525 100644 --- a/py-polars/tests/unit/constructors/test_constructors.py +++ b/py-polars/tests/unit/constructors/test_constructors.py @@ -1677,6 +1677,7 @@ def __arrow_c_array__(self, requested_schema: object = None) -> object: def test_pycapsule_interface(df: pl.DataFrame) -> None: + df = df.rechunk() pyarrow_table = df.to_arrow() # Array via C data interface diff --git a/py-polars/tests/unit/constructors/test_series.py b/py-polars/tests/unit/constructors/test_series.py index cfc0c76b6dc2..c31a5b48ce68 100644 --- a/py-polars/tests/unit/constructors/test_series.py +++ b/py-polars/tests/unit/constructors/test_series.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any import numpy as np +import pandas as pd import pytest import polars as pl @@ -148,6 +149,12 @@ def test_series_init_np_temporal_with_nat_15518() -> None: assert_series_equal(result, expected) +def test_series_init_pandas_timestamp_18127() -> None: + result = pl.Series([pd.Timestamp("2000-01-01T00:00:00.123456789", tz="UTC")]) + # Note: time unit is not (yet) respected, it should be Datetime('ns', 'UTC'). + assert result.dtype == pl.Datetime("us", "UTC") + + def test_series_init_np_2d_zero_zero_shape() -> None: arr = np.array([]).reshape(0, 0) with pytest.raises( diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index fe0b0fb04f82..c3472cc49d79 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -1555,7 +1555,7 @@ def test_reproducible_hash_with_seeds() -> None: assert_series_equal(expected, result, check_names=False, check_exact=True) -@pytest.mark.slow() +@pytest.mark.slow @pytest.mark.parametrize( "e", [ @@ -2881,7 +2881,7 @@ def test_sum_empty_column_names() -> None: def test_flags() -> None: - df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df = pl.DataFrame({"a": [1, 2, 3], "b": [9, 5, 6]}) assert df.flags == { "a": {"SORTED_ASC": False, "SORTED_DESC": False}, "b": {"SORTED_ASC": False, "SORTED_DESC": False}, diff --git a/py-polars/tests/unit/dataframe/test_from_dict.py b/py-polars/tests/unit/dataframe/test_from_dict.py index effdc47a4d4d..2b87e970d5cb 100644 --- a/py-polars/tests/unit/dataframe/test_from_dict.py +++ b/py-polars/tests/unit/dataframe/test_from_dict.py @@ -145,7 +145,7 @@ def test_from_dict_with_scalars() -> None: assert df9.rows() == [(0, 2, 0, "x"), (1, 1, 0, "x"), (2, 0, 0, "x")] -@pytest.mark.slow() +@pytest.mark.slow def test_from_dict_with_values_mixed() -> None: # a bit of everything mixed_dtype_data: dict[str, Any] = { diff --git a/py-polars/tests/unit/dataframe/test_partition_by.py b/py-polars/tests/unit/dataframe/test_partition_by.py index a793e19dccc1..c78a7043d4b7 100644 --- a/py-polars/tests/unit/dataframe/test_partition_by.py +++ b/py-polars/tests/unit/dataframe/test_partition_by.py @@ -6,7 +6,7 @@ import polars.selectors as cs -@pytest.fixture() +@pytest.fixture def df() -> pl.DataFrame: return pl.DataFrame( { @@ -75,7 +75,7 @@ def test_partition_by_as_dict_include_keys_false_maintain_order_false() -> None: df.partition_by(["a"], maintain_order=False, include_key=False, as_dict=True) -@pytest.mark.slow() +@pytest.mark.slow def test_partition_by_as_dict_include_keys_false_large() -> None: # test with both as_dict and include_key=False df = pl.DataFrame( diff --git a/py-polars/tests/unit/dataframe/test_serde.py 
b/py-polars/tests/unit/dataframe/test_serde.py index ea17420eb831..29d4eb5b05a6 100644 --- a/py-polars/tests/unit/dataframe/test_serde.py +++ b/py-polars/tests/unit/dataframe/test_serde.py @@ -62,9 +62,9 @@ def test_df_serde_json_stringio(df: pl.DataFrame) -> None: def test_df_serialize_json() -> None: - df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).sort("a") + df = pl.DataFrame({"a": [1, 2, 3], "b": [9, 5, 6]}).sort("a") result = df.serialize(format="json") - expected = '{"columns":[{"name":"a","datatype":"Int64","bit_settings":"SORTED_ASC","values":[1,2,3]},{"name":"b","datatype":"Int64","bit_settings":"","values":[4,5,6]}]}' + expected = '{"columns":[{"name":"a","datatype":"Int64","bit_settings":"SORTED_ASC","values":[1,2,3]},{"name":"b","datatype":"Int64","bit_settings":"","values":[9,5,6]}]}' assert result == expected @@ -85,7 +85,7 @@ def test_df_serde_to_from_buffer( assert_frame_equal(df, read_df, categorical_as_str=True) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_df_serde_to_from_file(df: pl.DataFrame, tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) diff --git a/py-polars/tests/unit/dataframe/test_upsample.py b/py-polars/tests/unit/dataframe/test_upsample.py index 163245fb1502..21160ad54df8 100644 --- a/py-polars/tests/unit/dataframe/test_upsample.py +++ b/py-polars/tests/unit/dataframe/test_upsample.py @@ -216,3 +216,71 @@ def test_upsample_index_invalid( every="1h", maintain_order=maintain_order, ) + + +def test_upsample_sorted_only_within_group() -> None: + df = pl.DataFrame( + { + "time": [ + datetime(2021, 4, 1), + datetime(2021, 2, 1), + datetime(2021, 5, 1), + datetime(2021, 6, 1), + ], + "admin": ["Netherlands", "Åland", "Åland", "Netherlands"], + "test2": [1, 0, 2, 3], + } + ) + + up = df.upsample( + time_column="time", + every="1mo", + group_by="admin", + maintain_order=True, + ).select(pl.all().forward_fill()) + + expected = pl.DataFrame( + { + "time": [ + datetime(2021, 4, 1, 0, 0), + datetime(2021, 5, 1, 0, 0), + datetime(2021, 6, 1, 0, 0), + datetime(2021, 2, 1, 0, 0), + datetime(2021, 3, 1, 0, 0), + datetime(2021, 4, 1, 0, 0), + datetime(2021, 5, 1, 0, 0), + ], + "admin": [ + "Netherlands", + "Netherlands", + "Netherlands", + "Åland", + "Åland", + "Åland", + "Åland", + ], + "test2": [1, 1, 3, 0, 0, 0, 2], + } + ) + + assert_frame_equal(up, expected) + + +def test_upsample_sorted_only_within_group_but_no_group_by_provided() -> None: + df = pl.DataFrame( + { + "time": [ + datetime(2021, 4, 1), + datetime(2021, 2, 1), + datetime(2021, 5, 1), + datetime(2021, 6, 1), + ], + "admin": ["Netherlands", "Åland", "Åland", "Netherlands"], + "test2": [1, 0, 2, 3], + } + ) + with pytest.raises( + InvalidOperationError, + match=r"argument in operation 'upsample' is not sorted, please sort the 'expr/series/column' first", + ): + df.upsample(time_column="time", every="1mo") diff --git a/py-polars/tests/unit/dataframe/test_vstack.py b/py-polars/tests/unit/dataframe/test_vstack.py index 138808225a1d..8649471f6099 100644 --- a/py-polars/tests/unit/dataframe/test_vstack.py +++ b/py-polars/tests/unit/dataframe/test_vstack.py @@ -5,12 +5,12 @@ from polars.testing import assert_frame_equal -@pytest.fixture() +@pytest.fixture def df1() -> pl.DataFrame: return pl.DataFrame({"foo": [1, 2], "bar": [6, 7], "ham": ["a", "b"]}) -@pytest.fixture() +@pytest.fixture def df2() -> pl.DataFrame: return pl.DataFrame({"foo": [3, 4], "bar": [8, 9], "ham": ["c", "d"]}) diff --git a/py-polars/tests/unit/datatypes/test_array.py b/py-polars/tests/unit/datatypes/test_array.py 
index 0ded3ee282f6..38e5e5c154fe 100644 --- a/py-polars/tests/unit/datatypes/test_array.py +++ b/py-polars/tests/unit/datatypes/test_array.py @@ -176,7 +176,7 @@ def test_cast_list_to_array(data: Any, inner_type: pl.DataType) -> None: assert s.to_list() == data -@pytest.fixture() +@pytest.fixture def data_dispersion() -> pl.DataFrame: return pl.DataFrame( { diff --git a/py-polars/tests/unit/datatypes/test_bool.py b/py-polars/tests/unit/datatypes/test_bool.py index 34a4b0d589a6..981167838e56 100644 --- a/py-polars/tests/unit/datatypes/test_bool.py +++ b/py-polars/tests/unit/datatypes/test_bool.py @@ -6,7 +6,7 @@ import polars as pl -@pytest.mark.slow() +@pytest.mark.slow def test_bool_arg_min_max() -> None: # masks that ensures we take more than u64 chunks # and slicing and dicing to ensure the offsets work diff --git a/py-polars/tests/unit/datatypes/test_categorical.py b/py-polars/tests/unit/datatypes/test_categorical.py index bd9f0524824d..b898c7c07999 100644 --- a/py-polars/tests/unit/datatypes/test_categorical.py +++ b/py-polars/tests/unit/datatypes/test_categorical.py @@ -480,7 +480,7 @@ def test_cast_inner_categorical() -> None: ) -@pytest.mark.slow() +@pytest.mark.slow def test_stringcache() -> None: N = 1_500 with pl.StringCache(): diff --git a/py-polars/tests/unit/datatypes/test_decimal.py b/py-polars/tests/unit/datatypes/test_decimal.py index 64b0d5fe5068..dd61f150ef75 100644 --- a/py-polars/tests/unit/datatypes/test_decimal.py +++ b/py-polars/tests/unit/datatypes/test_decimal.py @@ -29,7 +29,7 @@ def permutations_int_dec_none() -> list[tuple[D | int | None, ...]]: ) -@pytest.mark.slow() +@pytest.mark.slow def test_series_from_pydecimal_and_ints( permutations_int_dec_none: list[tuple[D | int | None, ...]], ) -> None: @@ -45,7 +45,7 @@ def test_series_from_pydecimal_and_ints( assert s.to_list() == [D(x) if x is not None else None for x in data] -@pytest.mark.slow() +@pytest.mark.slow def test_frame_from_pydecimal_and_ints( permutations_int_dec_none: list[tuple[D | int | None, ...]], monkeypatch: Any ) -> None: diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index 173502c30a97..4607cfa89426 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -667,7 +667,7 @@ def test_as_list_logical_type() -> None: ).to_dict(as_series=False) == {"literal": [True], "timestamp": [[date(2000, 1, 1)]]} -@pytest.fixture() +@pytest.fixture def data_dispersion() -> pl.DataFrame: return pl.DataFrame( { diff --git a/py-polars/tests/unit/datatypes/test_parse.py b/py-polars/tests/unit/datatypes/test_parse.py index 1eb4f6325525..c95763033b32 100644 --- a/py-polars/tests/unit/datatypes/test_parse.py +++ b/py-polars/tests/unit/datatypes/test_parse.py @@ -30,7 +30,7 @@ def assert_dtype_equal(left: PolarsDataType, right: PolarsDataType) -> None: assert left == right - assert type(left) == type(right) + assert type(left) is type(right) assert hash(left) == hash(right) diff --git a/py-polars/tests/unit/datatypes/test_struct.py b/py-polars/tests/unit/datatypes/test_struct.py index 265cc71d07c4..49a223f76fd4 100644 --- a/py-polars/tests/unit/datatypes/test_struct.py +++ b/py-polars/tests/unit/datatypes/test_struct.py @@ -623,7 +623,7 @@ def test_struct_categorical_5843() -> None: def test_empty_struct() -> None: # List df = pl.DataFrame({"a": [[{}]]}) - assert df.to_dict(as_series=False) == {"a": [[{"": None}]]} + assert df.to_dict(as_series=False) == {"a": [[None]]} # Struct one not empty df = 
pl.DataFrame({"a": [[{}, {"a": 10}]]}) @@ -631,7 +631,7 @@ def test_empty_struct() -> None: # Empty struct df = pl.DataFrame({"a": [{}]}) - assert df.to_dict(as_series=False) == {"a": [{"": None}]} + assert df.to_dict(as_series=False) == {"a": [None]} @pytest.mark.parametrize( @@ -710,7 +710,7 @@ def test_struct_null_cast() -> None: .lazy() .select([pl.lit(None, dtype=pl.Null).cast(dtype, strict=True)]) .collect() - ).to_dict(as_series=False) == {"literal": [{"a": None, "b": None, "c": None}]} + ).to_dict(as_series=False) == {"literal": [None]} def test_nested_struct_in_lists_cast() -> None: @@ -976,3 +976,50 @@ def test_named_exprs() -> None: res = df.select(pl.struct(schema=schema, b=pl.col("a"))) assert res.to_dict(as_series=False) == {"b": [{"b": 1}]} assert res.schema["b"] == pl.Struct(schema) + + +def test_struct_outer_nullability_zip_18119() -> None: + df = pl.Series("int", [0, 1, 2, 3], dtype=pl.Int64).to_frame() + assert df.lazy().with_columns( + result=pl.when(pl.col("int") >= 1).then( + pl.struct( + a=pl.when(pl.col("int") % 2 == 1).then(True), + b=pl.when(pl.col("int") >= 2).then(False), + ) + ) + ).collect().to_dict(as_series=False) == { + "int": [0, 1, 2, 3], + "result": [ + None, + {"a": True, "b": None}, + {"a": None, "b": False}, + {"a": True, "b": False}, + ], + } + + +def test_struct_group_by_shift_18107() -> None: + df_in = pl.DataFrame( + { + "group": [1, 1, 1, 2, 2, 2], + "id": [1, 2, 3, 4, 5, 6], + "value": [ + {"lon": 20, "lat": 10}, + {"lon": 30, "lat": 20}, + {"lon": 40, "lat": 30}, + {"lon": 50, "lat": 40}, + {"lon": 60, "lat": 50}, + {"lon": 70, "lat": 60}, + ], + } + ) + + assert df_in.group_by("group", maintain_order=True).agg( + pl.col("value").shift(-1) + ).to_dict(as_series=False) == { + "group": [1, 2], + "value": [ + [{"lon": 30, "lat": 20}, {"lon": 40, "lat": 30}, None], + [{"lon": 60, "lat": 50}, {"lon": 70, "lat": 60}, None], + ], + } diff --git a/py-polars/tests/unit/datatypes/test_temporal.py b/py-polars/tests/unit/datatypes/test_temporal.py index d08008bc8f41..ea1798fe7114 100644 --- a/py-polars/tests/unit/datatypes/test_temporal.py +++ b/py-polars/tests/unit/datatypes/test_temporal.py @@ -2294,12 +2294,13 @@ def test_weekday_vs_stdlib_datetime( ) -> None: result = ( pl.Series([value], dtype=pl.Datetime(time_unit)) - .dt.replace_time_zone(time_zone) + .dt.replace_time_zone(time_zone, non_existent="null", ambiguous="null") .dt.weekday() .item() ) - expected = value.isoweekday() - assert result == expected + if result is not None: + expected = value.isoweekday() + assert result == expected @given( diff --git a/py-polars/tests/unit/expr/test_exprs.py b/py-polars/tests/unit/expr/test_exprs.py index e234862d9b9e..a8f874e42548 100644 --- a/py-polars/tests/unit/expr/test_exprs.py +++ b/py-polars/tests/unit/expr/test_exprs.py @@ -4,7 +4,6 @@ from itertools import permutations from typing import TYPE_CHECKING, Any, cast -import numpy as np import pytest import polars as pl @@ -334,121 +333,6 @@ def test_arr_contains() -> None: } -def test_rank() -> None: - df = pl.DataFrame( - { - "a": [1, 1, 2, 2, 3], - } - ) - - s = df.select(pl.col("a").rank(method="average").alias("b")).to_series() - assert s.to_list() == [1.5, 1.5, 3.5, 3.5, 5.0] - assert s.dtype == pl.Float64 - - s = df.select(pl.col("a").rank(method="max").alias("b")).to_series() - assert s.to_list() == [2, 2, 4, 4, 5] - assert s.dtype == pl.get_index_type() - - -def test_rank_so_4109() -> None: - # also tests ranks null behavior - df = pl.from_dict( - { - "id": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 
4, 4, 4], - "rank": [None, 3, 2, 4, 1, 4, 3, 2, 1, None, 3, 4, 4, 1, None, 3], - } - ).sort(by=["id", "rank"]) - - assert df.group_by("id").agg( - [ - pl.col("rank").alias("original"), - pl.col("rank").rank(method="dense").alias("dense"), - pl.col("rank").rank(method="average").alias("average"), - ] - ).to_dict(as_series=False) == { - "id": [1, 2, 3, 4], - "original": [[None, 2, 3, 4], [1, 2, 3, 4], [None, 1, 3, 4], [None, 1, 3, 4]], - "dense": [[None, 1, 2, 3], [1, 2, 3, 4], [None, 1, 2, 3], [None, 1, 2, 3]], - "average": [ - [None, 1.0, 2.0, 3.0], - [1.0, 2.0, 3.0, 4.0], - [None, 1.0, 2.0, 3.0], - [None, 1.0, 2.0, 3.0], - ], - } - - -def test_rank_string_null_11252() -> None: - rank = pl.Series([None, "", "z", None, "a"]).rank() - assert rank.to_list() == [None, 1.0, 3.0, None, 2.0] - - -def test_search_sorted() -> None: - for seed in [1, 2, 3]: - np.random.seed(seed) - arr = np.sort(np.random.randn(10) * 100) - s = pl.Series(arr) - - for v in range(int(np.min(arr)), int(np.max(arr)), 20): - assert np.searchsorted(arr, v) == s.search_sorted(v) - - a = pl.Series([1, 2, 3]) - b = pl.Series([1, 2, 2, -1]) - assert a.search_sorted(b).to_list() == [0, 1, 1, 0] - b = pl.Series([1, 2, 2, None, 3]) - assert a.search_sorted(b).to_list() == [0, 1, 1, 0, 2] - - a = pl.Series(["b", "b", "d", "d"]) - b = pl.Series(["a", "b", "c", "d", "e"]) - assert a.search_sorted(b, side="left").to_list() == [0, 0, 2, 2, 4] - assert a.search_sorted(b, side="right").to_list() == [0, 2, 2, 4, 4] - - a = pl.Series([1, 1, 4, 4]) - b = pl.Series([0, 1, 2, 4, 5]) - assert a.search_sorted(b, side="left").to_list() == [0, 0, 2, 2, 4] - assert a.search_sorted(b, side="right").to_list() == [0, 2, 2, 4, 4] - - -def test_search_sorted_multichunk() -> None: - for seed in [1, 2, 3]: - np.random.seed(seed) - arr = np.sort(np.random.randn(10) * 100) - q = len(arr) // 4 - a, b, c, d = map( - pl.Series, (arr[:q], arr[q : 2 * q], arr[2 * q : 3 * q], arr[3 * q :]) - ) - s = pl.concat([a, b, c, d], rechunk=False) - assert s.n_chunks() == 4 - - for v in range(int(np.min(arr)), int(np.max(arr)), 20): - assert np.searchsorted(arr, v) == s.search_sorted(v) - - a = pl.concat( - [ - pl.Series([None, None, None], dtype=pl.Int64), - pl.Series([None, 1, 1, 2, 3]), - pl.Series([4, 4, 5, 6, 7, 8, 8]), - ], - rechunk=False, - ) - assert a.n_chunks() == 3 - b = pl.Series([-10, 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, None]) - left_ref = pl.Series( - [4, 4, 4, 6, 7, 8, 10, 11, 12, 13, 15, 0], dtype=pl.get_index_type() - ) - right_ref = pl.Series( - [4, 4, 6, 7, 8, 10, 11, 12, 13, 15, 15, 4], dtype=pl.get_index_type() - ) - assert_series_equal(a.search_sorted(b, side="left"), left_ref) - assert_series_equal(a.search_sorted(b, side="right"), right_ref) - - -def test_search_sorted_right_nulls() -> None: - a = pl.Series([1, 2, None, None]) - assert a.search_sorted(None, side="left") == 2 - assert a.search_sorted(None, side="right") == 4 - - def test_logical_boolean() -> None: # note, cannot use expressions in logical # boolean context (eg: and/or/not operators) diff --git a/py-polars/tests/unit/functions/range/test_date_range.py b/py-polars/tests/unit/functions/range/test_date_range.py index 0c15adae778a..a881d30c1e41 100644 --- a/py-polars/tests/unit/functions/range/test_date_range.py +++ b/py-polars/tests/unit/functions/range/test_date_range.py @@ -310,3 +310,17 @@ def test_date_ranges_datetime_input() -> None: "literal", [[date(2022, 1, 1), date(2022, 1, 2), date(2022, 1, 3)]] ) assert_series_equal(result, expected) + + +def 
test_date_range_with_subclass_18470_18447() -> None: + class MyAmazingDate(date): + pass + + class MyAmazingDatetime(datetime): + pass + + result = pl.datetime_range( + MyAmazingDate(2020, 1, 1), MyAmazingDatetime(2020, 1, 2), eager=True + ) + expected = pl.Series("literal", [datetime(2020, 1, 1), datetime(2020, 1, 2)]) + assert_series_equal(result, expected) diff --git a/py-polars/tests/unit/functions/test_concat.py b/py-polars/tests/unit/functions/test_concat.py index 8e7c4c9f31e3..e73c7fc9997f 100644 --- a/py-polars/tests/unit/functions/test_concat.py +++ b/py-polars/tests/unit/functions/test_concat.py @@ -3,7 +3,7 @@ import polars as pl -@pytest.mark.slow() +@pytest.mark.slow def test_concat_expressions_stack_overflow() -> None: n = 10000 e = pl.concat([pl.lit(x) for x in range(n)]) @@ -12,7 +12,7 @@ def test_concat_expressions_stack_overflow() -> None: assert df.shape == (n, 1) -@pytest.mark.slow() +@pytest.mark.slow def test_concat_lf_stack_overflow() -> None: n = 1000 bar = pl.DataFrame({"a": 0}).lazy() diff --git a/py-polars/tests/unit/functions/test_functions.py b/py-polars/tests/unit/functions/test_functions.py index a5ee29ae530f..de7e49574393 100644 --- a/py-polars/tests/unit/functions/test_functions.py +++ b/py-polars/tests/unit/functions/test_functions.py @@ -254,7 +254,7 @@ def test_align_frames() -> None: assert_frame_equal(pl_dot, pl.from_pandas(pd_dot)) pd.testing.assert_frame_equal(pd_dot, pl_dot.to_pandas()) - # (also: confirm alignment function works with lazyframes) + # confirm alignment function works with lazy frames lf1, lf2 = pl.align_frames( pl.from_pandas(pdf1.reset_index()).lazy(), pl.from_pandas(pdf2.reset_index()).lazy(), @@ -264,7 +264,7 @@ def test_align_frames() -> None: assert_frame_equal(lf1.collect(), pf1) assert_frame_equal(lf2.collect(), pf2) - # misc + # misc: no frames results in an empty list assert pl.align_frames(on="date") == [] # expected error condition @@ -275,6 +275,8 @@ def test_align_frames() -> None: on="date", ) + +def test_align_frames_misc() -> None: # descending result df1 = pl.DataFrame([[3, 5, 6], [5, 8, 9]], orient="row") df2 = pl.DataFrame([[2, 5, 6], [3, 8, 9], [4, 2, 0]], orient="row") @@ -290,6 +292,19 @@ def test_align_frames() -> None: assert pf.rows() == [(5, None, None), (4, 2, 0), (3, 8, 9), (2, 5, 6)] +def test_align_frames_with_nulls() -> None: + df1 = pl.DataFrame({"key": ["x", "y", None], "value": [1, 2, 0]}) + df2 = pl.DataFrame({"key": ["x", None, "z", "y"], "value": [4, 3, 6, 5]}) + + a1, a2 = pl.align_frames(df1, df2, on="key") + + aligned_frame_data = a1.to_dict(as_series=False), a2.to_dict(as_series=False) + assert aligned_frame_data == ( + {"key": [None, "x", "y", "z"], "value": [0, 1, 2, None]}, + {"key": [None, "x", "y", "z"], "value": [3, 4, 5, 6]}, + ) + + def test_align_frames_duplicate_key() -> None: # setup some test frames with duplicate key/alignment values df1 = pl.DataFrame({"x": ["a", "a", "a", "e"], "y": [1, 2, 4, 5]}) diff --git a/py-polars/tests/unit/functions/test_lit.py b/py-polars/tests/unit/functions/test_lit.py index 4b8ca8abe796..1f13ba122825 100644 --- a/py-polars/tests/unit/functions/test_lit.py +++ b/py-polars/tests/unit/functions/test_lit.py @@ -1,7 +1,7 @@ from __future__ import annotations import enum -from datetime import datetime, timedelta +from datetime import date, datetime, timedelta from decimal import Decimal from typing import TYPE_CHECKING, Any @@ -173,6 +173,17 @@ def test_lit_decimal() -> None: assert result == value +def test_lit_string_float() -> None: + value = 3.2 + + 
expr = pl.lit(value, dtype=pl.Utf8) + df = pl.select(expr) + result = df.item() + + assert df.dtypes[0] == pl.String + assert result == str(value) + + @given(s=series(min_size=1, max_size=1, allow_null=False, allowed_dtypes=pl.Decimal)) def test_lit_decimal_parametric(s: pl.Series) -> None: scale = s.dtype.scale # type: ignore[attr-defined] @@ -184,3 +195,27 @@ def test_lit_decimal_parametric(s: pl.Series) -> None: assert df.dtypes[0] == pl.Decimal(None, scale) assert result == value + + +def test_lit_datetime_subclass_w_allow_object() -> None: + class MyAmazingDate(date): + pass + + class MyAmazingDatetime(datetime): + pass + + result = pl.select( + a=pl.lit(MyAmazingDatetime(2020, 1, 1)), + b=pl.lit(MyAmazingDate(2020, 1, 1)), + c=pl.lit(MyAmazingDatetime(2020, 1, 1), allow_object=True), + d=pl.lit(MyAmazingDate(2020, 1, 1), allow_object=True), + ) + expected = pl.DataFrame( + { + "a": [datetime(2020, 1, 1)], + "b": [date(2020, 1, 1)], + "c": [datetime(2020, 1, 1)], + "d": [date(2020, 1, 1)], + } + ) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/functions/test_when_then.py b/py-polars/tests/unit/functions/test_when_then.py index 0690782bbad9..6f70b9c1b9d1 100644 --- a/py-polars/tests/unit/functions/test_when_then.py +++ b/py-polars/tests/unit/functions/test_when_then.py @@ -529,7 +529,7 @@ def test_when_then_null_broadcast() -> None: ) -@pytest.mark.slow() +@pytest.mark.slow @pytest.mark.parametrize("len", [1, 10, 100, 500]) @pytest.mark.parametrize( ("dtype", "vals"), diff --git a/py-polars/tests/unit/interchange/test_from_dataframe.py b/py-polars/tests/unit/interchange/test_from_dataframe.py index 62f34666f8b1..35fcc595451a 100644 --- a/py-polars/tests/unit/interchange/test_from_dataframe.py +++ b/py-polars/tests/unit/interchange/test_from_dataframe.py @@ -406,13 +406,13 @@ def test_construct_offsets_buffer_copy() -> None: assert_series_equal(result, expected) -@pytest.fixture() +@pytest.fixture def bitmask() -> PolarsBuffer: data = pl.Series([False, True, True, False]) return PolarsBuffer(data) -@pytest.fixture() +@pytest.fixture def bytemask() -> PolarsBuffer: data = pl.Series([0, 1, 1, 0], dtype=pl.UInt8) return PolarsBuffer(data) diff --git a/py-polars/tests/unit/interop/test_from_pandas.py b/py-polars/tests/unit/interop/test_from_pandas.py index 0e3e985bf7e3..c22d8abafd21 100644 --- a/py-polars/tests/unit/interop/test_from_pandas.py +++ b/py-polars/tests/unit/interop/test_from_pandas.py @@ -22,6 +22,17 @@ def test_index_not_silently_excluded() -> None: pl.from_pandas(df, include_index=True) +def test_nameless_multiindex_doesnt_raise_with_include_index_false_18130() -> None: + df = pd.DataFrame( + range(4), + columns=["A"], + index=pd.MultiIndex.from_product((["C", "D"], [3, 4])), + ) + result = pl.from_pandas(df) + expected = pl.DataFrame({"A": [0, 1, 2, 3]}) + assert_frame_equal(result, expected) + + def test_from_pandas() -> None: df = pd.DataFrame( { @@ -310,7 +321,7 @@ def test_untrusted_categorical_input() -> None: assert_frame_equal(result, expected, categorical_as_str=True) -@pytest.fixture() +@pytest.fixture def _set_pyarrow_unavailable(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( "polars._utils.construction.dataframe._PYARROW_AVAILABLE", False diff --git a/py-polars/tests/unit/io/cloud/test_aws.py b/py-polars/tests/unit/io/cloud/test_aws.py index 8dd8acef795f..9a004a7d825f 100644 --- a/py-polars/tests/unit/io/cloud/test_aws.py +++ b/py-polars/tests/unit/io/cloud/test_aws.py @@ -49,7 +49,7 @@ def s3_base(monkeypatch_module: Any) 
-> Iterator[str]: p.kill() -@pytest.fixture() +@pytest.fixture def s3(s3_base: str, io_files_path: Path) -> str: region = "us-east-1" client = boto3.client("s3", region_name=region, endpoint_url=s3_base) @@ -99,6 +99,6 @@ def test_lazy_count_s3(s3: str) -> None: "s3://bucket/foods*.parquet", storage_options={"endpoint_url": s3} ).select(pl.len()) - assert "FAST COUNT(*)" in lf.explain() + assert "FAST_COUNT" in lf.explain() expected = pl.DataFrame({"len": [54]}, schema={"len": pl.UInt32}) assert_frame_equal(lf.collect(), expected) diff --git a/py-polars/tests/unit/io/cloud/test_cloud.py b/py-polars/tests/unit/io/cloud/test_cloud.py index baf08d2b474a..f943ab5e2c26 100644 --- a/py-polars/tests/unit/io/cloud/test_cloud.py +++ b/py-polars/tests/unit/io/cloud/test_cloud.py @@ -4,7 +4,7 @@ from polars.exceptions import ComputeError -@pytest.mark.slow() +@pytest.mark.slow @pytest.mark.parametrize("format", ["parquet", "csv", "ndjson", "ipc"]) def test_scan_nonexistent_cloud_path_17444(format: str) -> None: # https://github.com/pola-rs/polars/issues/17444 diff --git a/py-polars/tests/unit/io/conftest.py b/py-polars/tests/unit/io/conftest.py index fd174486b25f..df245c097be4 100644 --- a/py-polars/tests/unit/io/conftest.py +++ b/py-polars/tests/unit/io/conftest.py @@ -5,6 +5,6 @@ import pytest -@pytest.fixture() +@pytest.fixture def io_files_path() -> Path: return Path(__file__).parent / "files" diff --git a/py-polars/tests/unit/io/database/conftest.py b/py-polars/tests/unit/io/database/conftest.py index 2ff027329ee0..8fcdc27547a3 100644 --- a/py-polars/tests/unit/io/database/conftest.py +++ b/py-polars/tests/unit/io/database/conftest.py @@ -10,7 +10,7 @@ from pathlib import Path -@pytest.fixture() +@pytest.fixture def tmp_sqlite_db(tmp_path: Path) -> Path: test_db = tmp_path / "test.db" test_db.unlink(missing_ok=True) @@ -51,7 +51,7 @@ def convert_date(val: bytes) -> date: return test_db -@pytest.fixture() +@pytest.fixture def tmp_sqlite_inference_db(tmp_path: Path) -> Path: test_db = tmp_path / "test_inference.db" test_db.unlink(missing_ok=True) diff --git a/py-polars/tests/unit/io/database/test_read.py b/py-polars/tests/unit/io/database/test_read.py index 3157e09cc3dd..b50e88602dac 100644 --- a/py-polars/tests/unit/io/database/test_read.py +++ b/py-polars/tests/unit/io/database/test_read.py @@ -18,7 +18,7 @@ import polars as pl from polars._utils.various import parse_version -from polars.exceptions import UnsuitableSQLError +from polars.exceptions import DuplicateError, UnsuitableSQLError from polars.io.database._arrow_registry import ARROW_DRIVER_REGISTRY from polars.testing import assert_frame_equal @@ -150,7 +150,7 @@ class ExceptionTestParams(NamedTuple): kwargs: dict[str, Any] | None = None -@pytest.mark.write_disk() +@pytest.mark.write_disk @pytest.mark.parametrize( ( "read_method", @@ -678,6 +678,23 @@ def test_read_database_exceptions( read_database(**params) +@pytest.mark.parametrize( + "query", + [ + "SELECT 1, 1 FROM test_data", + 'SELECT 1 AS "n", 2 AS "n" FROM test_data', + 'SELECT name, value AS "name" FROM test_data', + ], +) +def test_read_database_duplicate_column_error(tmp_sqlite_db: Path, query: str) -> None: + alchemy_conn = create_engine(f"sqlite:///{tmp_sqlite_db}").connect() + with pytest.raises( + DuplicateError, + match="column .+ appears more than once in the query/result cursor", + ): + pl.read_database(query, connection=alchemy_conn) + + @pytest.mark.parametrize( "uri", [ @@ -686,7 +703,7 @@ def test_read_database_exceptions( ], ) def 
test_read_database_cx_credentials(uri: str) -> None: - if sys.version_info > (3, 11): + if sys.version_info > (3, 9, 4): # slightly different error on more recent Python versions with pytest.raises(RuntimeError, match=r"Source.*not supported"): pl.read_database_uri("SELECT * FROM data", uri=uri, engine="connectorx") @@ -698,7 +715,7 @@ def test_read_database_cx_credentials(uri: str) -> None: pl.read_database_uri("SELECT * FROM data", uri=uri, engine="connectorx") -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_read_kuzu_graph_database(tmp_path: Path, io_files_path: Path) -> None: import kuzu diff --git a/py-polars/tests/unit/io/database/test_write.py b/py-polars/tests/unit/io/database/test_write.py index 1a995e31df64..da9550d7126a 100644 --- a/py-polars/tests/unit/io/database/test_write.py +++ b/py-polars/tests/unit/io/database/test_write.py @@ -18,7 +18,7 @@ from polars._typing import DbWriteEngine -@pytest.mark.write_disk() +@pytest.mark.write_disk @pytest.mark.parametrize( ("engine", "uri_connection"), [ @@ -237,7 +237,7 @@ def test_write_database_errors( df.write_database(connection=True, table_name="misc") # type: ignore[arg-type] -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_write_database_using_sa_session(tmp_path: str) -> None: df = pl.DataFrame( { @@ -261,7 +261,7 @@ def test_write_database_using_sa_session(tmp_path: str) -> None: assert_frame_equal(result, df) -@pytest.mark.write_disk() +@pytest.mark.write_disk @pytest.mark.parametrize("pass_connection", [True, False]) def test_write_database_sa_rollback(tmp_path: str, pass_connection: bool) -> None: df = pl.DataFrame( @@ -291,7 +291,7 @@ def test_write_database_sa_rollback(tmp_path: str, pass_connection: bool) -> Non assert count == 0 -@pytest.mark.write_disk() +@pytest.mark.write_disk @pytest.mark.parametrize("pass_connection", [True, False]) def test_write_database_sa_commit(tmp_path: str, pass_connection: bool) -> None: df = pl.DataFrame( @@ -318,3 +318,39 @@ def test_write_database_sa_commit(tmp_path: str, pass_connection: bool) -> None: ) assert_frame_equal(result, df) + + +@pytest.mark.skipif( + sys.version_info < (3, 9) or sys.platform == "win32", + reason="adbc not available on Windows or <= Python 3.8", +) +def test_write_database_adbc_temporary_table() -> None: + """Confirm that execution_options are passed along to create temporary tables.""" + df = pl.DataFrame({"colx": [1, 2, 3]}) + temp_tbl_name = "should_be_temptable" + expected_temp_table_create_sql = ( + """CREATE TABLE "should_be_temptable" ("colx" INTEGER)""" + ) + + # test with sqlite in memory + conn = _open_adbc_connection("sqlite:///:memory:") + assert ( + df.write_database( + temp_tbl_name, + connection=conn, + if_table_exists="fail", + engine_options={"temporary": True}, + ) + == 3 + ) + temp_tbl_sql_df = pl.read_database( + "select sql from sqlite_temp_master where type='table' and tbl_name = ?", + connection=conn, + execute_options={"parameters": [temp_tbl_name]}, + ) + assert temp_tbl_sql_df.shape[0] == 1, "no temp table created" + actual_temp_table_create_sql = temp_tbl_sql_df["sql"][0] + assert expected_temp_table_create_sql == actual_temp_table_create_sql + + if hasattr(conn, "close"): + conn.close() diff --git a/py-polars/tests/unit/io/test_avro.py b/py-polars/tests/unit/io/test_avro.py index 8a622f041c6e..bfd960a60228 100644 --- a/py-polars/tests/unit/io/test_avro.py +++ b/py-polars/tests/unit/io/test_avro.py @@ -17,7 +17,7 @@ COMPRESSIONS = ["uncompressed", "snappy", "deflate"] -@pytest.fixture() +@pytest.fixture def 
example_df() -> pl.DataFrame: return pl.DataFrame({"i64": [1, 2], "f64": [0.1, 0.2], "str": ["a", "b"]}) @@ -32,7 +32,7 @@ def test_from_to_buffer(example_df: pl.DataFrame, compression: AvroCompression) assert_frame_equal(example_df, read_df) -@pytest.mark.write_disk() +@pytest.mark.write_disk @pytest.mark.parametrize("compression", COMPRESSIONS) def test_from_to_file( example_df: pl.DataFrame, compression: AvroCompression, tmp_path: Path diff --git a/py-polars/tests/unit/io/test_csv.py b/py-polars/tests/unit/io/test_csv.py index 3a0c0e2450e9..fcacedead1d4 100644 --- a/py-polars/tests/unit/io/test_csv.py +++ b/py-polars/tests/unit/io/test_csv.py @@ -28,7 +28,7 @@ from tests.unit.conftest import MemoryUsage -@pytest.fixture() +@pytest.fixture def foods_file_path(io_files_path: Path) -> Path: return io_files_path / "foods1.csv" @@ -85,7 +85,7 @@ def test_to_from_buffer(df_no_lists: pl.DataFrame) -> None: assert_frame_equal(df.select("time", "cat"), read_df, categorical_as_str=True) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_to_from_file(df_no_lists: pl.DataFrame, tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -455,7 +455,7 @@ def test_read_csv_buffer_ownership() -> None: assert buf.read() == bts -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_read_csv_encoding(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -643,12 +643,12 @@ def test_empty_line_with_multiple_columns() -> None: def test_preserve_whitespace_at_line_start() -> None: df = pl.read_csv( - b"a\n b \n c\nd", + b" a\n b \n c\nd", new_columns=["A"], has_header=False, use_pyarrow=False, ) - expected = pl.DataFrame({"A": ["a", " b ", " c", "d"]}) + expected = pl.DataFrame({"A": [" a", " b ", " c", "d"]}) assert_frame_equal(df, expected) @@ -953,6 +953,7 @@ def test_write_csv_separator() -> None: df.write_csv(f, separator="\t") f.seek(0) assert f.read() == b"a\tb\n1\t1\n2\t2\n3\t3\n" + f.seek(0) assert_frame_equal(df, pl.read_csv(f, separator="\t")) @@ -962,6 +963,7 @@ def test_write_csv_line_terminator() -> None: df.write_csv(f, line_terminator="\r\n") f.seek(0) assert f.read() == b"a,b\r\n1,1\r\n2,2\r\n3,3\r\n" + f.seek(0) assert_frame_equal(df, pl.read_csv(f, eol_char="\n")) @@ -996,6 +998,7 @@ def test_quoting_round_trip() -> None: } ) df.write_csv(f) + f.seek(0) read_df = pl.read_csv(f) assert_frame_equal(read_df, df) @@ -1101,7 +1104,7 @@ def test_csv_string_escaping() -> None: assert_frame_equal(df_read, df) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_glob_csv(df_no_lists: pl.DataFrame, tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -1183,6 +1186,7 @@ def test_csv_write_escape_headers() -> None: out = io.BytesIO() df1.write_csv(out) + out.seek(0) df2 = pl.read_csv(out) assert_frame_equal(df1, df2) assert df2.schema == {"c,o,l,u,m,n": pl.Int64} @@ -1639,7 +1643,7 @@ def test_csv_statistics_offset() -> None: assert pl.read_csv(io.StringIO(csv), n_rows=N).height == 4999 -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_csv_scan_categorical(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -1653,7 +1657,7 @@ def test_csv_scan_categorical(tmp_path: Path) -> None: assert result["x"].dtype == pl.Categorical -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_csv_scan_new_columns_less_than_original_columns(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -1694,14 +1698,14 @@ def test_read_empty_csv(io_files_path: Path) -> None: assert_frame_equal(df, pl.DataFrame()) -@pytest.mark.slow() +@pytest.mark.slow def 
test_read_web_file() -> None: url = "https://raw.githubusercontent.com/pola-rs/polars/main/examples/datasets/foods1.csv" df = pl.read_csv(url) assert df.shape == (27, 4) -@pytest.mark.slow() +@pytest.mark.slow def test_csv_multiline_splits() -> None: # create a very unlikely csv file with many multilines in a # single field (e.g. 5000). polars must reject multi-threading here @@ -2008,7 +2012,7 @@ def test_invalid_csv_raise() -> None: ) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_partial_read_compressed_file( tmp_path: Path, monkeypatch: pytest.MonkeyPatch ) -> None: @@ -2066,8 +2070,8 @@ def test_csv_invalid_escape_utf8_14960() -> None: pl.read_csv('col1\n""•'.encode()) -@pytest.mark.slow() -@pytest.mark.write_disk() +@pytest.mark.slow +@pytest.mark.write_disk def test_read_csv_only_loads_selected_columns( memory_usage_without_pyarrow: MemoryUsage, tmp_path: Path, @@ -2133,7 +2137,7 @@ def test_csv_escape_cf_15349() -> None: assert f.read() == b'test\nnormal\n"with\rcr"\n' -@pytest.mark.write_disk() +@pytest.mark.write_disk @pytest.mark.parametrize("streaming", [True, False]) def test_skip_rows_after_header(tmp_path: Path, streaming: bool) -> None: tmp_path.mkdir(exist_ok=True) @@ -2233,7 +2237,7 @@ def test_projection_applied_on_file_with_no_rows_16606(tmp_path: Path) -> None: assert out == columns -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_write_csv_to_dangling_file_17328( df_no_lists: pl.DataFrame, tmp_path: Path ) -> None: @@ -2249,7 +2253,7 @@ def test_write_csv_raise_on_non_utf8_17328( df_no_lists.write_csv((tmp_path / "dangling.csv").open("w", encoding="gbk")) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_write_csv_appending_17543(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) df = pl.DataFrame({"col": ["value"]}) @@ -2279,4 +2283,5 @@ def test_read_csv_cast_unparsable_later( ) -> None: f = io.BytesIO() df.write_csv(f) + f.seek(0) assert df.equals(pl.read_csv(f, schema={"x": dtype})) diff --git a/py-polars/tests/unit/io/test_delta.py b/py-polars/tests/unit/io/test_delta.py index b7be5ab08061..33c6b052ffdf 100644 --- a/py-polars/tests/unit/io/test_delta.py +++ b/py-polars/tests/unit/io/test_delta.py @@ -8,14 +8,14 @@ import pyarrow.fs import pytest from deltalake import DeltaTable -from deltalake.exceptions import TableNotFoundError +from deltalake.exceptions import DeltaError, TableNotFoundError from deltalake.table import TableMerger import polars as pl from polars.testing import assert_frame_equal, assert_frame_not_equal -@pytest.fixture() +@pytest.fixture def delta_table_path(io_files_path: Path) -> Path: return io_files_path / "delta-table" @@ -34,7 +34,7 @@ def test_scan_delta_version(delta_table_path: Path) -> None: assert_frame_not_equal(df1, df2) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_scan_delta_timestamp_version(tmp_path: Path) -> None: df_sample = pl.DataFrame({"name": ["Joey"], "age": [14]}) df_sample.write_delta(tmp_path, mode="append") @@ -107,7 +107,7 @@ def test_read_delta_version(delta_table_path: Path) -> None: assert_frame_not_equal(df1, df2) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_read_delta_timestamp_version(tmp_path: Path) -> None: df_sample = pl.DataFrame({"name": ["Joey"], "age": [14]}) df_sample.write_delta(tmp_path, mode="append") @@ -163,7 +163,7 @@ def test_read_delta_relative(delta_table_path: Path) -> None: assert_frame_equal(expected, df, check_dtypes=False) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_write_delta(df: pl.DataFrame, 
tmp_path: Path) -> None: v0 = df.select(pl.col(pl.String)) v1 = df.select(pl.col(pl.Int64)) @@ -173,8 +173,8 @@ def test_write_delta(df: pl.DataFrame, tmp_path: Path) -> None: v0.write_delta(tmp_path) # Case: Error if table exists - with pytest.raises(ValueError): - v1.write_delta(tmp_path) + with pytest.raises(DeltaError, match="A table already exists"): + v0.write_delta(tmp_path) # Case: Overwrite with new version (version 1) v1.write_delta( @@ -245,18 +245,18 @@ def test_write_delta(df: pl.DataFrame, tmp_path: Path) -> None: df_supported.write_delta(partitioned_tbl_uri, mode="overwrite") -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_write_delta_overwrite_schema_deprecated( df: pl.DataFrame, tmp_path: Path ) -> None: df = df.select(pl.col(pl.Int64)) with pytest.deprecated_call(): - df.write_delta(tmp_path, overwrite_schema=True) + df.write_delta(tmp_path, mode="overwrite", overwrite_schema=True) result = pl.read_delta(str(tmp_path)) assert_frame_equal(df, result) -@pytest.mark.write_disk() +@pytest.mark.write_disk @pytest.mark.parametrize( "series", [ @@ -410,7 +410,7 @@ def test_write_delta_w_compatible_schema(series: pl.Series, tmp_path: Path) -> N assert tbl.version() == 1 -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_write_delta_with_schema_10540(tmp_path: Path) -> None: df = pl.DataFrame({"a": [1, 2, 3]}) @@ -418,7 +418,7 @@ def test_write_delta_with_schema_10540(tmp_path: Path) -> None: df.write_delta(tmp_path, delta_write_options={"schema": pa_schema}) -@pytest.mark.write_disk() +@pytest.mark.write_disk @pytest.mark.parametrize( "expr", [ @@ -455,7 +455,7 @@ def test_write_delta_with_merge_and_no_table(tmp_path: Path) -> None: ) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_write_delta_with_merge(tmp_path: Path) -> None: df = pl.DataFrame({"a": [1, 2, 3]}) @@ -472,9 +472,8 @@ def test_write_delta_with_merge(tmp_path: Path) -> None: ) assert isinstance(merger, TableMerger) - assert merger.predicate == "s.a = t.a" - assert merger.source_alias == "s" - assert merger.target_alias == "t" + assert merger._builder.source_alias == "s" + assert merger._builder.target_alias == "t" merger.when_matched_delete(predicate="t.a > 2").execute() @@ -484,7 +483,7 @@ def test_write_delta_with_merge(tmp_path: Path) -> None: assert_frame_equal(result, expected, check_row_order=False) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_unsupported_dtypes(tmp_path: Path) -> None: df = pl.DataFrame({"a": [None]}, schema={"a": pl.Null}) with pytest.raises(TypeError, match="unsupported data type"): @@ -495,7 +494,7 @@ def test_unsupported_dtypes(tmp_path: Path) -> None: df.write_delta(tmp_path / "time") -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_categorical_becomes_string(tmp_path: Path) -> None: df = pl.DataFrame({"a": ["A", "B", "A"]}, schema={"a": pl.Categorical}) df.write_delta(tmp_path) @@ -503,7 +502,7 @@ def test_categorical_becomes_string(tmp_path: Path) -> None: assert_frame_equal(df2, pl.DataFrame({"a": ["A", "B", "A"]}, schema={"a": pl.Utf8})) -@pytest.mark.write_disk() +@pytest.mark.write_disk @pytest.mark.parametrize("rechunk_and_expected_chunks", [(True, 1), (False, 3)]) def test_read_parquet_respects_rechunk_16982( rechunk_and_expected_chunks: tuple[bool, int], tmp_path: Path diff --git a/py-polars/tests/unit/io/test_hive.py b/py-polars/tests/unit/io/test_hive.py index c677b24e6e07..ad285b82f3b3 100644 --- a/py-polars/tests/unit/io/test_hive.py +++ b/py-polars/tests/unit/io/test_hive.py @@ -71,7 +71,7 @@ def 
impl_test_hive_partitioned_predicate_pushdown( @pytest.mark.xdist_group("streaming") -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_hive_partitioned_predicate_pushdown( io_files_path: Path, tmp_path: Path, @@ -87,7 +87,7 @@ def test_hive_partitioned_predicate_pushdown( @pytest.mark.xdist_group("streaming") -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_hive_partitioned_predicate_pushdown_single_threaded_async_17155( io_files_path: Path, tmp_path: Path, @@ -105,7 +105,7 @@ def test_hive_partitioned_predicate_pushdown_single_threaded_async_17155( ) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_hive_partitioned_predicate_pushdown_skips_correct_number_of_files( io_files_path: Path, tmp_path: Path, monkeypatch: Any, capfd: Any ) -> None: @@ -136,7 +136,7 @@ def test_hive_partitioned_predicate_pushdown_skips_correct_number_of_files( @pytest.mark.xdist_group("streaming") -@pytest.mark.write_disk() +@pytest.mark.write_disk @pytest.mark.parametrize("streaming", [True, False]) def test_hive_partitioned_slice_pushdown( io_files_path: Path, tmp_path: Path, streaming: bool @@ -170,7 +170,7 @@ def test_hive_partitioned_slice_pushdown( @pytest.mark.xdist_group("streaming") -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_hive_partitioned_projection_pushdown( io_files_path: Path, tmp_path: Path ) -> None: @@ -207,7 +207,7 @@ def test_hive_partitioned_projection_pushdown( assert_frame_equal(result, expected) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_hive_partitioned_projection_skip_files( io_files_path: Path, tmp_path: Path ) -> None: @@ -232,7 +232,7 @@ def test_hive_partitioned_projection_skip_files( assert_frame_equal(df, test_df) -@pytest.fixture() +@pytest.fixture def dataset_path(tmp_path: Path) -> Path: tmp_path.mkdir(exist_ok=True) @@ -253,7 +253,7 @@ def dataset_path(tmp_path: Path) -> Path: return root -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_scan_parquet_hive_schema(dataset_path: Path) -> None: result = pl.scan_parquet(dataset_path / "**/*.parquet", hive_partitioning=True) assert result.collect_schema() == OrderedDict( @@ -271,7 +271,7 @@ def test_scan_parquet_hive_schema(dataset_path: Path) -> None: assert result.collect().schema == expected_schema -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_read_parquet_invalid_hive_schema(dataset_path: Path) -> None: with pytest.raises( SchemaFieldNotFoundError, @@ -467,7 +467,7 @@ def test_hive_partition_schema_inference(tmp_path: Path) -> None: assert_series_equal(out["a"], expected[i]) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_hive_partition_force_async_17155(tmp_path: Path, monkeypatch: Any) -> None: monkeypatch.setenv("POLARS_FORCE_ASYNC", "1") monkeypatch.setenv("POLARS_PREFETCH_SIZE", "1") @@ -502,7 +502,7 @@ def test_hive_partition_force_async_17155(tmp_path: Path, monkeypatch: Any) -> N (pl.scan_ipc, pl.DataFrame.write_ipc), ], ) -@pytest.mark.write_disk() +@pytest.mark.write_disk @pytest.mark.parametrize("projection_pushdown", [True, False]) def test_hive_partition_columns_contained_in_file( tmp_path: Path, @@ -555,7 +555,7 @@ def assert_with_projections(lf: pl.LazyFrame, df: pl.DataFrame) -> None: assert_with_projections(lf, rhs) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_hive_partition_dates(tmp_path: Path) -> None: df = pl.DataFrame( { @@ -631,7 +631,7 @@ def test_hive_partition_dates(tmp_path: Path) -> None: (pl.scan_ipc, pl.DataFrame.write_ipc), ], ) -@pytest.mark.write_disk() 
+@pytest.mark.write_disk def test_projection_only_hive_parts_gives_correct_number_of_rows( tmp_path: Path, scan_func: Callable[[Any], pl.LazyFrame], @@ -669,7 +669,7 @@ def test_projection_only_hive_parts_gives_correct_number_of_rows( ), ], ) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_hive_write(tmp_path: Path, df: pl.DataFrame) -> None: root = tmp_path df.write_parquet(root, partition_by=["a", "b"]) @@ -681,8 +681,8 @@ def test_hive_write(tmp_path: Path, df: pl.DataFrame) -> None: assert_frame_equal(lf.collect(), df.with_columns(pl.col("a", "b").cast(pl.String))) -@pytest.mark.slow() -@pytest.mark.write_disk() +@pytest.mark.slow +@pytest.mark.write_disk def test_hive_write_multiple_files(tmp_path: Path) -> None: chunk_size = 262_144 n_rows = 100_000 @@ -699,7 +699,7 @@ def test_hive_write_multiple_files(tmp_path: Path) -> None: assert_frame_equal(pl.scan_parquet(root).collect(), df) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_hive_write_dates(tmp_path: Path) -> None: df = pl.DataFrame( { @@ -733,7 +733,7 @@ def test_hive_write_dates(tmp_path: Path) -> None: ) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_hive_predicate_dates_14712( tmp_path: Path, monkeypatch: Any, capfd: Any ) -> None: diff --git a/py-polars/tests/unit/io/test_iceberg.py b/py-polars/tests/unit/io/test_iceberg.py index 0d80816552f5..5a5f6769e3a5 100644 --- a/py-polars/tests/unit/io/test_iceberg.py +++ b/py-polars/tests/unit/io/test_iceberg.py @@ -12,7 +12,7 @@ from polars.io.iceberg import _convert_predicate, _to_ast -@pytest.fixture() +@pytest.fixture def iceberg_path(io_files_path: Path) -> str: # Iceberg requires absolute paths, so we'll symlink # the test table into /tmp/iceberg/t1/ @@ -26,12 +26,12 @@ def iceberg_path(io_files_path: Path) -> str: return f"file://{iceberg_path.resolve()}" -@pytest.mark.slow() -@pytest.mark.write_disk() +@pytest.mark.slow +@pytest.mark.write_disk @pytest.mark.filterwarnings( "ignore:No preferred file implementation for scheme*:UserWarning" ) -@pytest.mark.ci_only() +@pytest.mark.ci_only class TestIcebergScanIO: """Test coverage for `iceberg` scan ops.""" @@ -88,7 +88,7 @@ def test_scan_iceberg_filter_on_column(self, iceberg_path: str) -> None: ] -@pytest.mark.ci_only() +@pytest.mark.ci_only class TestIcebergExpressions: """Test coverage for `iceberg` expressions comprehension.""" diff --git a/py-polars/tests/unit/io/test_ipc.py b/py-polars/tests/unit/io/test_ipc.py index 1c0b7ed4516c..dd60d0ae209c 100644 --- a/py-polars/tests/unit/io/test_ipc.py +++ b/py-polars/tests/unit/io/test_ipc.py @@ -1,8 +1,6 @@ from __future__ import annotations import io -import os -import re from decimal import Decimal from typing import TYPE_CHECKING, Any @@ -10,7 +8,6 @@ import pytest import polars as pl -from polars.exceptions import ComputeError from polars.interchange.protocol import CompatLevel from polars.testing import assert_frame_equal @@ -44,11 +41,13 @@ def test_from_to_buffer( ) -> None: # use an ad-hoc buffer (file=None) buf1 = write_ipc(df, stream, None, compression=compression) + buf1.seek(0) read_df = read_ipc(stream, buf1, use_pyarrow=False) assert_frame_equal(df, read_df, categorical_as_str=True) # explicitly supply an existing buffer buf2 = io.BytesIO() + buf2.seek(0) write_ipc(df, stream, buf2, compression=compression) buf2.seek(0) read_df = read_ipc(stream, buf2, use_pyarrow=False) @@ -58,7 +57,7 @@ def test_from_to_buffer( @pytest.mark.parametrize("compression", COMPRESSIONS) @pytest.mark.parametrize("path_as_string", [True, False]) 
@pytest.mark.parametrize("stream", [True, False]) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_from_to_file( df: pl.DataFrame, compression: IpcCompression, @@ -77,7 +76,7 @@ def test_from_to_file( @pytest.mark.parametrize("stream", [True, False]) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_select_columns_from_file( df: pl.DataFrame, tmp_path: Path, stream: bool ) -> None: @@ -153,7 +152,7 @@ def test_ipc_schema(compression: IpcCompression) -> None: assert pl.read_ipc_schema(f) == expected -@pytest.mark.write_disk() +@pytest.mark.write_disk @pytest.mark.parametrize("compression", COMPRESSIONS) @pytest.mark.parametrize("path_as_string", [True, False]) def test_ipc_schema_from_file( @@ -208,7 +207,7 @@ def test_ipc_column_order(stream: bool) -> None: assert read_ipc(stream, f, columns=columns).columns == columns -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_glob_ipc(df: pl.DataFrame, tmp_path: Path) -> None: file_path = tmp_path / "small.ipc" df.write_ipc(file_path) @@ -231,7 +230,7 @@ def test_from_float16() -> None: assert pl.read_ipc(f, use_pyarrow=False).dtypes == [pl.Float32] -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_binview_ipc_mmap(tmp_path: Path) -> None: df = pl.DataFrame({"foo": ["aa" * 10, "bb", None, "small", "big" * 20]}) file_path = tmp_path / "dump.ipc" @@ -245,6 +244,7 @@ def test_list_nested_enum() -> None: df = pl.DataFrame(pl.Series("list_cat", [["a", "b", "c", None]], dtype=dtype)) buffer = io.BytesIO() df.write_ipc(buffer, compat_level=CompatLevel.newest()) + buffer.seek(0) df = pl.read_ipc(buffer) assert df.get_column("list_cat").dtype == dtype @@ -258,11 +258,12 @@ def test_struct_nested_enum() -> None: ) buffer = io.BytesIO() df.write_ipc(buffer, compat_level=CompatLevel.newest()) + buffer.seek(0) df = pl.read_ipc(buffer) assert df.get_column("struct_cat").dtype == dtype -@pytest.mark.slow() +@pytest.mark.slow def test_ipc_view_gc_14448() -> None: f = io.BytesIO() # This size was required to trigger the bug @@ -274,8 +275,8 @@ def test_ipc_view_gc_14448() -> None: assert_frame_equal(pl.read_ipc(f), df) -@pytest.mark.slow() -@pytest.mark.write_disk() +@pytest.mark.slow +@pytest.mark.write_disk @pytest.mark.parametrize("stream", [True, False]) def test_read_ipc_only_loads_selected_columns( memory_usage_without_pyarrow: MemoryUsage, @@ -313,7 +314,7 @@ def test_read_ipc_only_loads_selected_columns( assert 16_000_000 < memory_usage_without_pyarrow.get_peak() < 23_000_000 -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_ipc_decimal_15920( tmp_path: Path, ) -> None: @@ -339,29 +340,3 @@ def test_ipc_decimal_15920( path = f"{tmp_path}/data" df.write_ipc(path) assert_frame_equal(pl.read_ipc(path), df) - - -@pytest.mark.write_disk() -def test_ipc_raise_on_writing_mmap(tmp_path: Path) -> None: - p = tmp_path / "foo.ipc" - df = pl.DataFrame({"foo": [1, 2, 3]}) - # first write is allowed - df.write_ipc(p) - - # now open as memory mapped - df = pl.read_ipc(p, memory_map=True) - - if os.name == "nt": - # In Windows, it's the duty of the system to ensure exclusive access - with pytest.raises( - OSError, - match=re.escape( - "The requested operation cannot be performed on a file with a user-mapped section open. 
(os error 1224)" - ), - ): - df.write_ipc(p) - else: - with pytest.raises( - ComputeError, match="cannot write to file: already memory mapped" - ): - df.write_ipc(p) diff --git a/py-polars/tests/unit/io/test_json.py b/py-polars/tests/unit/io/test_json.py index 13e459294bd7..4bce4ee4e0ce 100644 --- a/py-polars/tests/unit/io/test_json.py +++ b/py-polars/tests/unit/io/test_json.py @@ -306,7 +306,7 @@ def test_ndjson_null_inference_13183() -> None: } -@pytest.mark.write_disk() +@pytest.mark.write_disk @typing.no_type_check def test_json_wrong_input_handle_textio(tmp_path: Path) -> None: # this shouldn't be passed, but still we test if we can handle it gracefully @@ -375,3 +375,13 @@ def test_json_normalize() -> None: "fitness.height": [130, 130, 130], "fitness.weight": [60, 60, 60], } + + +def test_empty_json() -> None: + df = pl.read_json(io.StringIO("{}")) + assert df.shape == (0, 0) + assert isinstance(df, pl.DataFrame) + + df = pl.read_json(b'{"j":{}}') + assert df.dtypes == [pl.Struct([])] + assert df.shape == (0, 1) diff --git a/py-polars/tests/unit/io/test_lazy_count_star.py b/py-polars/tests/unit/io/test_lazy_count_star.py index fa06b4695cae..a2c03596dd15 100644 --- a/py-polars/tests/unit/io/test_lazy_count_star.py +++ b/py-polars/tests/unit/io/test_lazy_count_star.py @@ -5,6 +5,7 @@ if TYPE_CHECKING: from pathlib import Path +import gzip from tempfile import NamedTemporaryFile import pytest @@ -22,11 +23,11 @@ def test_count_csv(io_files_path: Path, path: str, n_rows: int) -> None: expected = pl.DataFrame(pl.Series("len", [n_rows], dtype=pl.UInt32)) # Check if we are using our fast count star - assert "FAST COUNT(*)" in lf.explain() + assert "FAST COUNT" in lf.explain() assert_frame_equal(lf.collect(), expected) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_commented_csv() -> None: csv_a = NamedTemporaryFile() csv_a.write( @@ -41,7 +42,7 @@ def test_commented_csv() -> None: expected = pl.DataFrame(pl.Series("len", [2], dtype=pl.UInt32)) lf = pl.scan_csv(csv_a.name, comment_prefix="#").select(pl.len()) - assert "FAST COUNT(*)" in lf.explain() + assert "FAST COUNT" in lf.explain() assert_frame_equal(lf.collect(), expected) @@ -54,7 +55,7 @@ def test_count_parquet(io_files_path: Path, pattern: str, n_rows: int) -> None: expected = pl.DataFrame(pl.Series("len", [n_rows], dtype=pl.UInt32)) # Check if we are using our fast count star - assert "FAST COUNT(*)" in lf.explain() + assert "FAST COUNT" in lf.explain() assert_frame_equal(lf.collect(), expected) @@ -67,7 +68,7 @@ def test_count_ipc(io_files_path: Path, path: str, n_rows: int) -> None: expected = pl.DataFrame(pl.Series("len", [n_rows], dtype=pl.UInt32)) # Check if we are using our fast count star - assert "FAST COUNT(*)" in lf.explain() + assert "FAST COUNT" in lf.explain() assert_frame_equal(lf.collect(), expected) @@ -80,5 +81,32 @@ def test_count_ndjson(io_files_path: Path, path: str, n_rows: int) -> None: expected = pl.DataFrame(pl.Series("len", [n_rows], dtype=pl.UInt32)) # Check if we are using our fast count star - assert "FAST COUNT(*)" in lf.explain() + assert "FAST COUNT" in lf.explain() assert_frame_equal(lf.collect(), expected) + + +def test_count_compressed_csv_18057(io_files_path: Path) -> None: + csv_file = io_files_path / "gzipped.csv.gz" + + expected = pl.DataFrame( + {"a": [1, 2, 3], "b": ["a", "b", "c"], "c": [1.0, 2.0, 3.0]} + ) + lf = pl.scan_csv(csv_file, truncate_ragged_lines=True) + out = lf.collect() + assert_frame_equal(out, expected) + # This also tests: + # #18070 "CSV count_rows does not 
skip empty lines at file start" + # as the file has an empty line at the beginning. + assert lf.select(pl.len()).collect().item() == 3 + + +def test_count_compressed_ndjson(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + path = tmp_path / "data.jsonl.gz" + df = pl.DataFrame({"x": range(5)}) + + with gzip.open(path, "wb") as f: + df.write_ndjson(f) + + lf = pl.scan_ndjson(path) + assert lf.select(pl.len()).collect().item() == 5 diff --git a/py-polars/tests/unit/io/test_lazy_csv.py b/py-polars/tests/unit/io/test_lazy_csv.py index c2351ec109bc..aee85eb53a0d 100644 --- a/py-polars/tests/unit/io/test_lazy_csv.py +++ b/py-polars/tests/unit/io/test_lazy_csv.py @@ -12,7 +12,7 @@ from polars.testing import assert_frame_equal -@pytest.fixture() +@pytest.fixture def foods_file_path(io_files_path: Path) -> Path: return io_files_path / "foods1.csv" @@ -36,7 +36,7 @@ def test_scan_empty_csv(io_files_path: Path) -> None: assert_frame_equal(lf, pl.LazyFrame()) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_invalid_utf8(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -206,7 +206,7 @@ def test_lazy_row_index_no_push_down(foods_file_path: Path) -> None: assert 'SELECTION: [(col("category")) == (String(vegetables))]' in plan -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_glob_skip_rows(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -276,7 +276,7 @@ def test_scan_csv_slice_offset_zero(io_files_path: Path) -> None: assert result.collect().height == 4 -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_scan_empty_csv_with_row_index(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) file_path = tmp_path / "small.parquet" @@ -287,7 +287,7 @@ def test_scan_empty_csv_with_row_index(tmp_path: Path) -> None: assert read.collect().schema == OrderedDict([("idx", pl.UInt32), ("a", pl.String)]) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_csv_null_values_with_projection_15515() -> None: data = """IndCode,SireCode,BirthDate,Flag ID00316,.,19940315, @@ -309,7 +309,7 @@ def test_csv_null_values_with_projection_15515() -> None: } -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_csv_respect_user_schema_ragged_lines_15254() -> None: with tempfile.NamedTemporaryFile() as f: f.write( @@ -438,18 +438,3 @@ def test_scan_csv_with_column_names_nonexistent_file() -> None: # Upon collection, it should fail with pytest.raises(FileNotFoundError): result.collect() - - -def test_scan_csv_compressed_row_count_18057(io_files_path: Path) -> None: - csv_file = io_files_path / "gzipped.csv.gz" - - expected = pl.DataFrame( - {"a": [1, 2, 3], "b": ["a", "b", "c"], "c": [1.0, 2.0, 3.0]} - ) - lf = pl.scan_csv(csv_file, truncate_ragged_lines=True) - out = lf.collect() - assert_frame_equal(out, expected) - # This also tests: - # #18070 "CSV count_rows does not skip empty lines at file start" - # as the file has an empty line at the beginning. 
- assert lf.select(pl.len()).collect().item() == 3 diff --git a/py-polars/tests/unit/io/test_lazy_ipc.py b/py-polars/tests/unit/io/test_lazy_ipc.py index 354a912ea4b6..0d67b6b06f89 100644 --- a/py-polars/tests/unit/io/test_lazy_ipc.py +++ b/py-polars/tests/unit/io/test_lazy_ipc.py @@ -11,7 +11,7 @@ from pathlib import Path -@pytest.fixture() +@pytest.fixture def foods_ipc_path(io_files_path: Path) -> Path: return io_files_path / "foods1.ipc" diff --git a/py-polars/tests/unit/io/test_lazy_json.py b/py-polars/tests/unit/io/test_lazy_json.py index 3ecd6f538f0a..a5a53d78f94c 100644 --- a/py-polars/tests/unit/io/test_lazy_json.py +++ b/py-polars/tests/unit/io/test_lazy_json.py @@ -11,7 +11,7 @@ from pathlib import Path -@pytest.fixture() +@pytest.fixture def foods_ndjson_path(io_files_path: Path) -> Path: return io_files_path / "foods1.ndjson" @@ -66,7 +66,7 @@ def test_scan_ndjson_batch_size_zero() -> None: pl.scan_ndjson("test.ndjson", batch_size=0) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_scan_with_projection(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) diff --git a/py-polars/tests/unit/io/test_lazy_parquet.py b/py-polars/tests/unit/io/test_lazy_parquet.py index 7596322d8748..e21844003375 100644 --- a/py-polars/tests/unit/io/test_lazy_parquet.py +++ b/py-polars/tests/unit/io/test_lazy_parquet.py @@ -16,12 +16,12 @@ from polars._typing import ParallelStrategy -@pytest.fixture() +@pytest.fixture def parquet_file_path(io_files_path: Path) -> Path: return io_files_path / "small.parquet" -@pytest.fixture() +@pytest.fixture def foods_parquet_path(io_files_path: Path) -> Path: return io_files_path / "foods1.parquet" @@ -65,7 +65,7 @@ def test_row_index_len_16543(foods_parquet_path: Path) -> None: assert q.select(pl.all()).select(pl.len()).collect().item() == 27 -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_categorical_parquet_statistics(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -104,7 +104,7 @@ def test_categorical_parquet_statistics(tmp_path: Path) -> None: assert df.shape == (4, 3) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_parquet_eq_stats(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -125,7 +125,7 @@ def test_parquet_eq_stats(tmp_path: Path) -> None: ) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_parquet_is_in_stats(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -165,7 +165,7 @@ def test_parquet_is_in_stats(tmp_path: Path) -> None: ).collect().shape == (8, 1) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_parquet_stats(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -208,7 +208,7 @@ def test_row_index_schema_parquet(parquet_file_path: Path) -> None: ).dtypes == [pl.UInt32, pl.String] -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_parquet_is_in_statistics(monkeypatch: Any, capfd: Any, tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -244,7 +244,7 @@ def test_parquet_is_in_statistics(monkeypatch: Any, capfd: Any, tmp_path: Path) ) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_parquet_statistics(monkeypatch: Any, capfd: Any, tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -280,7 +280,7 @@ def test_parquet_statistics(monkeypatch: Any, capfd: Any, tmp_path: Path) -> Non ) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_categorical(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -325,7 +325,7 @@ def test_glob_n_rows(io_files_path: Path) -> None: } -@pytest.mark.write_disk() 
+@pytest.mark.write_disk def test_parquet_statistics_filter_9925(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) file_path = tmp_path / "codes.parquet" @@ -338,7 +338,7 @@ def test_parquet_statistics_filter_9925(tmp_path: Path) -> None: assert q.collect().to_dict(as_series=False) == {"code": [300964, 300972, 26]} -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_parquet_statistics_filter_11069(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) file_path = tmp_path / "foo.parquet" @@ -359,7 +359,7 @@ def test_parquet_list_arg(io_files_path: Path) -> None: assert df.row(0) == ("vegetables", 45, 0.5, 2) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_parquet_many_row_groups_12297(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) file_path = tmp_path / "foo.parquet" @@ -368,7 +368,7 @@ def test_parquet_many_row_groups_12297(tmp_path: Path) -> None: assert_frame_equal(pl.scan_parquet(file_path).collect(), df) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_row_index_empty_file(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) file_path = tmp_path / "test.parquet" @@ -378,7 +378,7 @@ def test_row_index_empty_file(tmp_path: Path) -> None: assert result.schema == OrderedDict([("idx", pl.UInt32), ("a", pl.Float32)]) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_io_struct_async_12500(tmp_path: Path) -> None: file_path = tmp_path / "test.parquet" pl.DataFrame( @@ -392,7 +392,7 @@ def test_io_struct_async_12500(tmp_path: Path) -> None: ) == {"c1": [{"a": "foo", "b": "bar"}]} -@pytest.mark.write_disk() +@pytest.mark.write_disk @pytest.mark.parametrize("streaming", [True, False]) def test_parquet_different_schema(tmp_path: Path, streaming: bool) -> None: # Schema is different but the projected columns are same dtype. 
@@ -409,7 +409,7 @@ def test_parquet_different_schema(tmp_path: Path, streaming: bool) -> None: ).columns == ["b"] -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_nested_slice_12480(tmp_path: Path) -> None: path = tmp_path / "data.parquet" df = pl.select(pl.lit(1).repeat_by(10_000).explode().cast(pl.List(pl.Int32))) @@ -419,7 +419,7 @@ def test_nested_slice_12480(tmp_path: Path) -> None: assert pl.scan_parquet(path).slice(0, 1).collect().height == 1 -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_scan_deadlock_rayon_spawn_from_async_15172( monkeypatch: Any, tmp_path: Path ) -> None: @@ -443,7 +443,7 @@ def scan_collect() -> None: assert results[0].equals(df) -@pytest.mark.write_disk() +@pytest.mark.write_disk @pytest.mark.parametrize("streaming", [True, False]) def test_parquet_schema_mismatch_panic_17067(tmp_path: Path, streaming: bool) -> None: pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).write_parquet(tmp_path / "1.parquet") @@ -453,7 +453,7 @@ def test_parquet_schema_mismatch_panic_17067(tmp_path: Path, streaming: bool) -> pl.scan_parquet(tmp_path).collect(streaming=streaming) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_predicate_push_down_categorical_17744(tmp_path: Path) -> None: path = tmp_path / "1" @@ -504,11 +504,15 @@ def trim_to_metadata(path: str | Path) -> None: assert pl.read_parquet_schema(paths[0]) == dfs[0].schema # * Attempting to read any data will error with pytest.raises(ComputeError): - pl.scan_parquet(paths[0]).collect() + pl.scan_parquet(paths[0]).collect(streaming=streaming) df = dfs[1] - assert_frame_equal(pl.scan_parquet(paths).slice(1, 1).collect(), df) - assert_frame_equal(pl.scan_parquet(paths[1:]).head(1).collect(), df) + assert_frame_equal( + pl.scan_parquet(paths).slice(1, 1).collect(streaming=streaming), df + ) + assert_frame_equal( + pl.scan_parquet(paths[1:]).head(1).collect(streaming=streaming), df + ) # Negative slice unsupported in streaming if not streaming: diff --git a/py-polars/tests/unit/io/test_other.py b/py-polars/tests/unit/io/test_other.py index d6562911adbb..4c08250838d8 100644 --- a/py-polars/tests/unit/io/test_other.py +++ b/py-polars/tests/unit/io/test_other.py @@ -124,7 +124,7 @@ def test_unit_io_subdir_has_no_init() -> None: ).exists(), "Found undesirable '__init__.py' in the 'unit.io' tests subdirectory" -@pytest.mark.write_disk() +@pytest.mark.write_disk @pytest.mark.parametrize( ("scan_funcs", "write_func"), [ diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index 18f15bfd40c0..3da465561bd1 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -3,7 +3,7 @@ import io from datetime import datetime, time, timezone from decimal import Decimal -from typing import TYPE_CHECKING, Any, cast +from typing import IO, TYPE_CHECKING, Any, Literal, cast import fsspec import numpy as np @@ -12,12 +12,13 @@ import pyarrow.dataset as ds import pyarrow.parquet as pq import pytest -from hypothesis import HealthCheck, given, settings +from hypothesis import given +from hypothesis import strategies as st import polars as pl from polars.exceptions import ComputeError from polars.testing import assert_frame_equal, assert_series_equal -from polars.testing.parametric import dataframes +from polars.testing.parametric import column, dataframes if TYPE_CHECKING: from pathlib import Path @@ -33,12 +34,12 @@ def test_round_trip(df: pl.DataFrame) -> None: assert_frame_equal(pl.read_parquet(f), df) -def 
test_scan_round_trip(tmp_path: Path, df: pl.DataFrame) -> None: - tmp_path.mkdir(exist_ok=True) - f = tmp_path / "test.parquet" - +def test_scan_round_trip(df: pl.DataFrame) -> None: + f = io.BytesIO() df.write_parquet(f) + f.seek(0) assert_frame_equal(pl.scan_parquet(f).collect(), df) + f.seek(0) assert_frame_equal(pl.scan_parquet(f).head().collect(), df.head()) @@ -53,7 +54,7 @@ def test_scan_round_trip(tmp_path: Path, df: pl.DataFrame) -> None: ] -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_write_parquet_using_pyarrow_9753(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -93,7 +94,7 @@ def test_write_parquet_using_pyarrow_write_to_dataset_with_partitioning( assert_frame_equal(df, read_df) -@pytest.fixture() +@pytest.fixture def small_parquet_path(io_files_path: Path) -> Path: return io_files_path / "small.parquet" @@ -148,7 +149,7 @@ def test_to_from_buffer_lzo(df: pl.DataFrame) -> None: _ = pl.read_parquet(buf) -@pytest.mark.write_disk() +@pytest.mark.write_disk @pytest.mark.parametrize("compression", COMPRESSIONS) def test_to_from_file( df: pl.DataFrame, compression: ParquetCompression, tmp_path: Path @@ -161,7 +162,7 @@ def test_to_from_file( assert_frame_equal(df, read_df, categorical_as_str=True) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_to_from_file_lzo(df: pl.DataFrame, tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -245,7 +246,7 @@ def test_nested_parquet() -> None: assert isinstance(read.dtypes[0].inner, pl.datatypes.Struct) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_glob_parquet(df: pl.DataFrame, tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) file_path = tmp_path / "small.parquet" @@ -278,7 +279,7 @@ def test_chunked_round_trip() -> None: assert_frame_equal(pl.read_parquet(f), df) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_lazy_self_join_file_cache_prop_3979(df: pl.DataFrame, tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -430,7 +431,7 @@ def test_parquet_nested_dictionaries_6217() -> None: assert_frame_equal(read, df) # type: ignore[arg-type] -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_fetch_union(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -456,7 +457,7 @@ def test_fetch_union(tmp_path: Path) -> None: assert_frame_equal(result_glob, expected) -@pytest.mark.slow() +@pytest.mark.slow def test_struct_pyarrow_dataset_5796(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -471,7 +472,7 @@ def test_struct_pyarrow_dataset_5796(tmp_path: Path) -> None: assert_frame_equal(result, df) # type: ignore[arg-type] -@pytest.mark.slow() +@pytest.mark.slow @pytest.mark.parametrize("case", [1048576, 1048577]) def test_parquet_chunks_545(case: int) -> None: f = io.BytesIO() @@ -584,7 +585,7 @@ def test_nested_struct_read_12610() -> None: assert_frame_equal(expect, actual) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_decimal_parquet(tmp_path: Path) -> None: path = tmp_path / "foo.parquet" df = pl.DataFrame( @@ -601,7 +602,7 @@ def test_decimal_parquet(tmp_path: Path) -> None: assert out == {"foo": [2], "bar": [Decimal("7")]} -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_enum_parquet(tmp_path: Path) -> None: path = tmp_path / "enum.parquet" df = pl.DataFrame( @@ -629,7 +630,7 @@ def test_parquet_rle_non_nullable_12814() -> None: assert_frame_equal(expect, actual) -@pytest.mark.slow() +@pytest.mark.slow def test_parquet_12831() -> None: n = 70_000 df = pl.DataFrame({"x": ["aaaaaa"] * n}) @@ -639,7 +640,7 @@ def 
test_parquet_12831() -> None: assert_frame_equal(pl.from_arrow(pq.read_table(f)), df) # type: ignore[arg-type] -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_parquet_struct_categorical(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -658,7 +659,7 @@ def test_parquet_struct_categorical(tmp_path: Path) -> None: assert out.to_dict(as_series=False) == {"b": [{"b": "foo", "count": 1}]} -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_null_parquet(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -669,7 +670,7 @@ def test_null_parquet(tmp_path: Path) -> None: assert_frame_equal(out, df) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_write_parquet_with_null_col(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -684,7 +685,20 @@ def test_write_parquet_with_null_col(tmp_path: Path) -> None: assert_frame_equal(out, df) -@pytest.mark.write_disk() +@pytest.mark.write_disk +def test_scan_parquet_binary_buffered_reader(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + df = pl.DataFrame({"a": [1, 2, 3]}) + file_path = tmp_path / "test.parquet" + df.write_parquet(file_path) + + with file_path.open("rb") as f: + out = pl.scan_parquet(f).collect() + assert_frame_equal(out, df) + + +@pytest.mark.write_disk def test_read_parquet_binary_buffered_reader(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -697,7 +711,7 @@ def test_read_parquet_binary_buffered_reader(tmp_path: Path) -> None: assert_frame_equal(out, df) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_read_parquet_binary_file_io(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -711,7 +725,7 @@ def test_read_parquet_binary_file_io(tmp_path: Path) -> None: # https://github.com/pola-rs/polars/issues/15760 -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_read_parquet_binary_fsspec(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -918,21 +932,20 @@ def test_parquet_array_dtype_nulls() -> None: ), ], ) -@pytest.mark.write_disk() -def test_complex_types(tmp_path: Path, series: list[Any], dtype: pl.DataType) -> None: +def test_complex_types(series: list[Any], dtype: pl.DataType) -> None: xs = pl.Series(series, dtype=dtype) df = pl.DataFrame({"x": xs}) test_round_trip(df) -@pytest.mark.xfail() +@pytest.mark.xfail def test_placeholder_zero_array() -> None: # @TODO: if this does not fail anymore please enable the upper test-cases pl.Series([[]], dtype=pl.Array(pl.Int8, 0)) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_parquet_array_statistics(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -948,8 +961,8 @@ def test_parquet_array_statistics(tmp_path: Path) -> None: assert result.to_dict(as_series=False) == {"a": [[4, 5, 6], [7, 8, 9]], "b": [2, 3]} -@pytest.mark.slow() -@pytest.mark.write_disk() +@pytest.mark.slow +@pytest.mark.write_disk def test_read_parquet_only_loads_selected_columns_15098( memory_usage_without_pyarrow: MemoryUsage, tmp_path: Path ) -> None: @@ -979,27 +992,25 @@ def test_read_parquet_only_loads_selected_columns_15098( assert 8_000_000 < memory_usage_without_pyarrow.get_peak() < 13_000_000 -@pytest.mark.release() -@pytest.mark.write_disk() -def test_max_statistic_parquet_writer(tmp_path: Path) -> None: +@pytest.mark.release +def test_max_statistic_parquet_writer() -> None: # this hits the maximal page size # so the row group will be split into multiple pages # the page statistics need to be correctly reduced # for this query to make sense n = 150_000 - tmp_path.mkdir(exist_ok=True) - # int64 
is important to hit the page size df = pl.int_range(0, n, eager=True, dtype=pl.Int64).alias("int").to_frame() - f = tmp_path / "tmp.parquet" + f = io.BytesIO() df.write_parquet(f, statistics=True, use_pyarrow=False, row_group_size=n) + f.seek(0) result = pl.scan_parquet(f).filter(pl.col("int") > n - 3).collect() expected = pl.DataFrame({"int": [149998, 149999]}) assert_frame_equal(result, expected) -@pytest.mark.slow() +@pytest.mark.slow def test_hybrid_rle() -> None: # 10_007 elements to test if not a nice multiple of 8 n = 10_007 @@ -1058,8 +1069,8 @@ def test_hybrid_rle() -> None: f = io.BytesIO() df.write_parquet(f) f.seek(0) - for column in pq.ParquetFile(f).metadata.to_dict()["row_groups"][0]["columns"]: - assert "RLE_DICTIONARY" in column["encodings"] + for col in pq.ParquetFile(f).metadata.to_dict()["row_groups"][0]["columns"]: + assert "RLE_DICTIONARY" in col["encodings"] f.seek(0) assert_frame_equal(pl.read_parquet(f), df) @@ -1086,15 +1097,12 @@ def test_hybrid_rle() -> None: max_size=5000, ) ) -@pytest.mark.slow() -@pytest.mark.write_disk() -@settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) -def test_roundtrip_parametric(df: pl.DataFrame, tmp_path: Path) -> None: - # delete if exists - path = tmp_path / "data.parquet" - - df.write_parquet(path) - result = pl.read_parquet(path) +@pytest.mark.slow +def test_roundtrip_parametric(df: pl.DataFrame) -> None: + f = io.BytesIO() + df.write_parquet(f) + f.seek(0) + result = pl.read_parquet(f) assert_frame_equal(df, result) @@ -1111,7 +1119,7 @@ def test_parquet_statistics_uint64_16683() -> None: assert statistics.max == u64_max -@pytest.mark.slow() +@pytest.mark.slow @pytest.mark.parametrize("nullable", [True, False]) def test_read_byte_stream_split(nullable: bool) -> None: rng = np.random.default_rng(123) @@ -1143,7 +1151,7 @@ def test_read_byte_stream_split(nullable: bool) -> None: assert_frame_equal(read, df) -@pytest.mark.slow() +@pytest.mark.slow @pytest.mark.parametrize("rows_nullable", [True, False]) @pytest.mark.parametrize("item_nullable", [True, False]) def test_read_byte_stream_split_arrays( @@ -1206,18 +1214,14 @@ def test_read_byte_stream_split_arrays( assert_frame_equal(read, df) -@pytest.mark.write_disk() -def test_parquet_nested_null_array_17795(tmp_path: Path) -> None: - filename = tmp_path / "nested_null.parquet" - - pl.DataFrame([{"struct": {"field": None}}]).write_parquet(filename) - pq.read_table(filename) - +def test_parquet_nested_null_array_17795() -> None: + f = io.BytesIO() + pl.DataFrame([{"struct": {"field": None}}]).write_parquet(f) + f.seek(0) + pq.read_table(f) -@pytest.mark.write_disk() -def test_parquet_record_batches_pyarrow_fixed_size_list_16614(tmp_path: Path) -> None: - filename = tmp_path / "a.parquet" +def test_parquet_record_batches_pyarrow_fixed_size_list_16614() -> None: # @NOTE: # The minimum that I could get it to crash which was ~132000, but let's # just do 150000 to be sure. 
@@ -1227,27 +1231,28 @@ def test_parquet_record_batches_pyarrow_fixed_size_list_16614(tmp_path: Path) -> schema={"x": pl.Array(pl.Float32, 2)}, ) - x.write_parquet(filename) - b = pl.read_parquet(filename, use_pyarrow=True) + f = io.BytesIO() + x.write_parquet(f) + f.seek(0) + b = pl.read_parquet(f, use_pyarrow=True) assert b["x"].shape[0] == n assert_frame_equal(b, x) -@pytest.mark.write_disk() -def test_parquet_list_element_field_name(tmp_path: Path) -> None: - filename = tmp_path / "list.parquet" - +def test_parquet_list_element_field_name() -> None: + f = io.BytesIO() ( pl.DataFrame( { "a": [[1, 2], [1, 1, 1]], }, schema={"a": pl.List(pl.Int64)}, - ).write_parquet(filename, use_pyarrow=False) + ).write_parquet(f, use_pyarrow=False) ) - schema_str = str(pq.read_schema(filename)) + f.seek(0) + schema_str = str(pq.read_schema(f)) assert "" in schema_str assert "child 0, element: int64" in schema_str @@ -1367,8 +1372,7 @@ def test_parquet_high_nested_null_17805( ) -@pytest.mark.write_disk() -def test_struct_plain_encoded_statistics(tmp_path: Path) -> None: +def test_struct_plain_encoded_statistics() -> None: df = pl.DataFrame( { "a": [None, None, None, None, {"x": None, "y": 0}], @@ -1376,17 +1380,26 @@ def test_struct_plain_encoded_statistics(tmp_path: Path) -> None: schema={"a": pl.Struct({"x": pl.Int8, "y": pl.Int8})}, ) - test_scan_round_trip(tmp_path, df) + test_scan_round_trip(df) @given(df=dataframes(min_size=5, excluded_dtypes=[pl.Decimal, pl.Categorical])) -@settings( - max_examples=100, - deadline=None, - suppress_health_check=[HealthCheck.function_scoped_fixture], -) -def test_scan_round_trip_parametric(tmp_path: Path, df: pl.DataFrame) -> None: - test_scan_round_trip(tmp_path, df) +def test_scan_round_trip_parametric(df: pl.DataFrame) -> None: + test_scan_round_trip(df) + + +def test_empty_rg_no_dict_page_18146() -> None: + df = pl.DataFrame( + { + "a": [], + }, + schema={"a": pl.String}, + ) + + f = io.BytesIO() + pq.write_table(df.to_arrow(), f, compression="NONE", use_dictionary=False) + f.seek(0) + assert_frame_equal(pl.read_parquet(f), df) def test_write_sliced_lists_18069() -> None: @@ -1400,3 +1413,481 @@ def test_write_sliced_lists_18069() -> None: after = pl.read_parquet(f) assert_frame_equal(before, after) + + +def test_null_array_dict_pages_18085() -> None: + test = pd.DataFrame( + [ + {"A": float("NaN"), "B": 3, "C": None}, + {"A": float("NaN"), "B": None, "C": None}, + ] + ) + + f = io.BytesIO() + test.to_parquet(f) + f.seek(0) + pl.read_parquet(f) + + +@given( + df=dataframes( + min_size=1, + max_size=1000, + allowed_dtypes=[ + pl.List, + pl.Int8, + pl.Int16, + pl.Int32, + pl.Int64, + pl.UInt8, + pl.UInt16, + pl.UInt32, + pl.UInt64, + ], + ), + row_group_size=st.integers(min_value=10, max_value=1000), +) +def test_delta_encoding_roundtrip(df: pl.DataFrame, row_group_size: int) -> None: + f = io.BytesIO() + pq.write_table( + df.to_arrow(), + f, + compression="NONE", + use_dictionary=False, + column_encoding="DELTA_BINARY_PACKED", + write_statistics=False, + row_group_size=row_group_size, + ) + + f.seek(0) + assert_frame_equal(pl.read_parquet(f), df) + + +@given( + df=dataframes(min_size=1, max_size=1000, allowed_dtypes=[pl.String, pl.Binary]), + row_group_size=st.integers(min_value=10, max_value=1000), +) +def test_delta_length_byte_array_encoding_roundtrip( + df: pl.DataFrame, row_group_size: int +) -> None: + f = io.BytesIO() + pq.write_table( + df.to_arrow(), + f, + compression="NONE", + use_dictionary=False, + column_encoding="DELTA_LENGTH_BYTE_ARRAY", + 
write_statistics=False, + row_group_size=row_group_size, + ) + + f.seek(0) + assert_frame_equal(pl.read_parquet(f), df) + + +@given( + df=dataframes(min_size=1, max_size=1000, allowed_dtypes=[pl.String, pl.Binary]), + row_group_size=st.integers(min_value=10, max_value=1000), +) +def test_delta_strings_encoding_roundtrip( + df: pl.DataFrame, row_group_size: int +) -> None: + f = io.BytesIO() + pq.write_table( + df.to_arrow(), + f, + compression="NONE", + use_dictionary=False, + column_encoding="DELTA_BYTE_ARRAY", + write_statistics=False, + row_group_size=row_group_size, + ) + + f.seek(0) + assert_frame_equal(pl.read_parquet(f), df) + + +EQUALITY_OPERATORS = ["__eq__", "__lt__", "__le__", "__gt__", "__ge__"] +BOOLEAN_OPERATORS = ["__or__", "__and__"] + + +@given( + df=dataframes( + min_size=0, max_size=100, min_cols=2, max_cols=5, allowed_dtypes=[pl.Int32] + ), + first_op=st.sampled_from(EQUALITY_OPERATORS), + second_op=st.sampled_from( + [None] + + [ + (booljoin, eq) + for booljoin in BOOLEAN_OPERATORS + for eq in EQUALITY_OPERATORS + ] + ), + l1=st.integers(min_value=0, max_value=1000), + l2=st.integers(min_value=0, max_value=1000), + r1=st.integers(min_value=0, max_value=1000), + r2=st.integers(min_value=0, max_value=1000), +) +@pytest.mark.parametrize("parallel_st", ["auto", "prefiltered"]) +def test_predicate_filtering( + df: pl.DataFrame, + first_op: str, + second_op: None | tuple[str, str], + l1: int, + l2: int, + r1: int, + r2: int, + parallel_st: Literal["auto", "prefiltered"], +) -> None: + f = io.BytesIO() + df.write_parquet(f, row_group_size=5) + + cols = df.columns + + l1s = cols[l1 % len(cols)] + l2s = cols[l2 % len(cols)] + expr = (getattr(pl.col(l1s), first_op))(pl.col(l2s)) + + if second_op is not None: + r1s = cols[r1 % len(cols)] + r2s = cols[r2 % len(cols)] + expr = getattr(expr, second_op[0])( + (getattr(pl.col(r1s), second_op[1]))(pl.col(r2s)) + ) + + f.seek(0) + result = pl.scan_parquet(f, parallel=parallel_st).filter(expr).collect() + assert_frame_equal(result, df.filter(expr)) + + +@given( + df=dataframes( + min_size=1, + max_size=5, + min_cols=1, + max_cols=1, + excluded_dtypes=[pl.Decimal, pl.Categorical, pl.Enum], + ), + offset=st.integers(0, 100), + length=st.integers(0, 100), +) +def test_slice_roundtrip(df: pl.DataFrame, offset: int, length: int) -> None: + offset %= df.height + 1 + length %= df.height - offset + 1 + + f = io.BytesIO() + df.write_parquet(f) + + f.seek(0) + scanned = pl.scan_parquet(f).slice(offset, length).collect() + assert_frame_equal(scanned, df.slice(offset, length)) + + +def test_struct_prefiltered() -> None: + df = pl.DataFrame({"a": {"x": 1, "y": 2}}) + f = io.BytesIO() + df.write_parquet(f) + + f.seek(0) + ( + pl.scan_parquet(f, parallel="prefiltered") + .filter(pl.col("a").struct.field("x") == 1) + .collect() + ) + + +@pytest.mark.parametrize( + "data", + [ + ( + [{"x": ""}, {"x": "0"}], + pa.struct([pa.field("x", pa.string(), nullable=True)]), + ), + ( + [{"x": ""}, {"x": "0"}], + pa.struct([pa.field("x", pa.string(), nullable=False)]), + ), + ([[""], ["0"]], pa.list_(pa.field("item", pa.string(), nullable=False))), + ([[""], ["0"]], pa.list_(pa.field("item", pa.string(), nullable=True))), + ([[""], ["0"]], pa.list_(pa.field("item", pa.string(), nullable=False), 1)), + ([[""], ["0"]], pa.list_(pa.field("item", pa.string(), nullable=True), 1)), + ( + [["", "1"], ["0", "2"]], + pa.list_(pa.field("item", pa.string(), nullable=False), 2), + ), + ( + [["", "1"], ["0", "2"]], + pa.list_(pa.field("item", pa.string(), nullable=True), 2), + 
), + ], +) +@pytest.mark.parametrize("nullable", [False, True]) +def test_nested_skip_18303( + data: tuple[list[dict[str, str] | list[str]], pa.DataType], + nullable: bool, +) -> None: + schema = pa.schema([pa.field("a", data[1], nullable=nullable)]) + tb = pa.table({"a": data[0]}, schema=schema) + + f = io.BytesIO() + pq.write_table(tb, f) + + f.seek(0) + scanned = pl.scan_parquet(f).slice(1, 1).collect() + + assert_frame_equal(scanned, pl.DataFrame(tb).slice(1, 1)) + + +def test_nested_span_multiple_pages_18400() -> None: + width = 4100 + df = pl.DataFrame( + [ + pl.Series( + "a", + [ + list(range(width)), + list(range(width)), + ], + pl.Array(pl.Int64, width), + ), + ] + ) + + f = io.BytesIO() + pq.write_table( + df.to_arrow(), + f, + use_dictionary=False, + data_page_size=1024, + column_encoding={"a": "PLAIN"}, + ) + + f.seek(0) + assert_frame_equal(df.head(1), pl.read_parquet(f, n_rows=1)) + + +@given( + df=dataframes( + min_size=0, + max_size=1000, + min_cols=2, + max_cols=5, + excluded_dtypes=[pl.Decimal, pl.Categorical, pl.Enum, pl.Array], + include_cols=[column("filter_col", pl.Boolean, allow_null=False)], + ), +) +def test_parametric_small_page_mask_filtering(df: pl.DataFrame) -> None: + f = io.BytesIO() + df.write_parquet(f, data_page_size=1024) + + expr = pl.col("filter_col") + f.seek(0) + result = pl.scan_parquet(f, parallel="prefiltered").filter(expr).collect() + assert_frame_equal(result, df.filter(expr)) + + +@pytest.mark.parametrize( + "value", + [ + "abcd", + 0, + 0.0, + False, + ], +) +def test_different_page_validity_across_pages(value: str | int | float | bool) -> None: + df = pl.DataFrame( + { + "a": [None] + [value] * 4000, + } + ) + + f = io.BytesIO() + pq.write_table( + df.to_arrow(), + f, + use_dictionary=False, + data_page_size=1024, + column_encoding={"a": "PLAIN"}, + ) + + f.seek(0) + assert_frame_equal(df, pl.read_parquet(f)) + + +@given( + df=dataframes( + min_size=0, + max_size=100, + min_cols=2, + max_cols=5, + allowed_dtypes=[pl.String, pl.Binary], + include_cols=[ + column("filter_col", pl.Int8, st.integers(0, 1), allow_null=False) + ], + ), +) +def test_delta_length_byte_array_prefiltering(df: pl.DataFrame) -> None: + cols = df.columns + + encodings = {col: "DELTA_LENGTH_BYTE_ARRAY" for col in cols} + encodings["filter_col"] = "PLAIN" + + f = io.BytesIO() + pq.write_table( + df.to_arrow(), + f, + use_dictionary=False, + column_encoding=encodings, + ) + + f.seek(0) + expr = pl.col("filter_col") == 0 + result = pl.scan_parquet(f, parallel="prefiltered").filter(expr).collect() + assert_frame_equal(result, df.filter(expr)) + + +@given( + df=dataframes( + min_size=0, + max_size=10, + min_cols=1, + max_cols=5, + excluded_dtypes=[pl.Decimal, pl.Categorical, pl.Enum], + include_cols=[ + column("filter_col", pl.Int8, st.integers(0, 1), allow_null=False) + ], + ), +) +def test_general_prefiltering(df: pl.DataFrame) -> None: + f = io.BytesIO() + df.write_parquet(f) + + expr = pl.col("filter_col") == 0 + + f.seek(0) + result = pl.scan_parquet(f, parallel="prefiltered").filter(expr).collect() + assert_frame_equal(result, df.filter(expr)) + + +@given( + df=dataframes( + min_size=0, + max_size=10, + min_cols=1, + max_cols=5, + excluded_dtypes=[pl.Decimal, pl.Categorical, pl.Enum], + include_cols=[column("filter_col", pl.Boolean, allow_null=False)], + ), +) +def test_row_index_prefiltering(df: pl.DataFrame) -> None: + f = io.BytesIO() + df.write_parquet(f) + + expr = pl.col("filter_col") + + f.seek(0) + result = ( + pl.scan_parquet( + f, row_index_name="ri", 
row_index_offset=42, parallel="prefiltered" + ) + .filter(expr) + .collect() + ) + assert_frame_equal(result, df.with_row_index("ri", 42).filter(expr)) + + +def test_empty_parquet() -> None: + f_pd = io.BytesIO() + f_pl = io.BytesIO() + + pd.DataFrame().to_parquet(f_pd) + pl.DataFrame().write_parquet(f_pl) + + f_pd.seek(0) + f_pl.seek(0) + + empty_from_pd = pl.read_parquet(f_pd) + assert empty_from_pd.shape == (0, 0) + + empty_from_pl = pl.read_parquet(f_pl) + assert empty_from_pl.shape == (0, 0) + + +@pytest.mark.parametrize( + "strategy", + ["columns", "row_groups", "prefiltered"], +) +@pytest.mark.write_disk +def test_row_index_projection_pushdown_18463( + tmp_path: Path, strategy: pl.ParallelStrategy +) -> None: + tmp_path.mkdir(exist_ok=True) + f = tmp_path / "test.parquet" + + pl.DataFrame({"A": [1, 4], "B": [2, 5]}).write_parquet(f) + + df = pl.scan_parquet(f, parallel=strategy).with_row_index() + + assert_frame_equal(df.select("index").collect(), df.collect().select("index")) + + df = pl.scan_parquet(f, parallel=strategy).with_row_index("other_idx_name") + + assert_frame_equal( + df.select("other_idx_name").collect(), df.collect().select("other_idx_name") + ) + + df = pl.scan_parquet(f, parallel=strategy).with_row_index(offset=42) + + assert_frame_equal(df.select("index").collect(), df.collect().select("index")) + + df = pl.scan_parquet(f, parallel=strategy).with_row_index() + + assert_frame_equal( + df.select("index").slice(1, 1).collect(), + df.collect().select("index").slice(1, 1), + ) + + +def test_concat_multiple_inmem() -> None: + f = io.BytesIO() + g = io.BytesIO() + + df1 = pl.DataFrame( + { + "a": [1, 2, 3], + "b": ["xyz", "abc", "wow"], + } + ) + df2 = pl.DataFrame( + { + "a": [5, 6, 7], + "b": ["a", "few", "entries"], + } + ) + + dfs = pl.concat([df1, df2]) + + df1.write_parquet(f) + df2.write_parquet(g) + + f.seek(0) + g.seek(0) + + items: list[IO[bytes]] = [f, g] + assert_frame_equal(pl.read_parquet(items), dfs) + + f.seek(0) + g.seek(0) + + assert_frame_equal(pl.read_parquet(items, use_pyarrow=True), dfs) + + f.seek(0) + g.seek(0) + + fb = f.read() + gb = g.read() + + assert_frame_equal(pl.read_parquet([fb, gb]), dfs) + assert_frame_equal(pl.read_parquet([fb, gb], use_pyarrow=True), dfs) diff --git a/py-polars/tests/unit/io/test_pyarrow_dataset.py b/py-polars/tests/unit/io/test_pyarrow_dataset.py index aa4bccb14717..01dd9f208355 100644 --- a/py-polars/tests/unit/io/test_pyarrow_dataset.py +++ b/py-polars/tests/unit/io/test_pyarrow_dataset.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Callable import pyarrow.dataset as ds -import pytest import polars as pl from polars.testing import assert_frame_equal @@ -18,16 +17,22 @@ def helper_dataset_test( query: Callable[[pl.LazyFrame], pl.LazyFrame], batch_size: int | None = None, n_expected: int | None = None, + check_predicate_pushdown: bool = False, ) -> None: dset = ds.dataset(file_path, format="ipc") - expected = pl.scan_ipc(file_path).pipe(query).collect() + q = pl.scan_ipc(file_path).pipe(query) + + expected = q.collect() out = pl.scan_pyarrow_dataset(dset, batch_size=batch_size).pipe(query).collect() assert_frame_equal(out, expected) if n_expected is not None: assert len(out) == n_expected + if check_predicate_pushdown: + assert "FILTER" not in q.explain() + -@pytest.mark.write_disk() +# @pytest.mark.write_disk() def test_pyarrow_dataset_source(df: pl.DataFrame, tmp_path: Path) -> None: file_path = tmp_path / "small.ipc" df.write_ipc(file_path) @@ -36,11 +41,13 @@ def test_pyarrow_dataset_source(df: pl.DataFrame, 
tmp_path: Path) -> None: file_path, lambda lf: lf.filter("bools").select("bools", "floats", "date"), n_expected=1, + check_predicate_pushdown=True, ) helper_dataset_test( file_path, lambda lf: lf.filter(~pl.col("bools")).select("bools", "floats", "date"), n_expected=2, + check_predicate_pushdown=True, ) helper_dataset_test( file_path, @@ -48,6 +55,7 @@ def test_pyarrow_dataset_source(df: pl.DataFrame, tmp_path: Path) -> None: "bools", "floats", "date" ), n_expected=1, + check_predicate_pushdown=True, ) helper_dataset_test( file_path, @@ -55,6 +63,7 @@ def test_pyarrow_dataset_source(df: pl.DataFrame, tmp_path: Path) -> None: "bools", "floats", "date" ), n_expected=2, + check_predicate_pushdown=True, ) helper_dataset_test( file_path, @@ -62,6 +71,7 @@ def test_pyarrow_dataset_source(df: pl.DataFrame, tmp_path: Path) -> None: pl.col("int_nulls").is_not_null() == pl.col("bools") ).select("bools", "floats", "date"), n_expected=0, + check_predicate_pushdown=True, ) # this equality on a column with nulls fails as pyarrow has different # handling kleene logic. We leave it for now and document it in the function. @@ -71,6 +81,7 @@ def test_pyarrow_dataset_source(df: pl.DataFrame, tmp_path: Path) -> None: "bools", "floats", "int_nulls" ), n_expected=0, + check_predicate_pushdown=True, ) helper_dataset_test( file_path, @@ -78,6 +89,7 @@ def test_pyarrow_dataset_source(df: pl.DataFrame, tmp_path: Path) -> None: "bools", "floats", "int_nulls" ), n_expected=3, + check_predicate_pushdown=True, ) for closed, n_expected in zip(["both", "left", "right", "none"], [3, 2, 2, 1]): @@ -87,6 +99,7 @@ def test_pyarrow_dataset_source(df: pl.DataFrame, tmp_path: Path) -> None: pl.col("int").is_between(1, 3, closed=closed) ).select("bools", "floats", "date"), n_expected=n_expected, + check_predicate_pushdown=True, ) # this predicate is not supported by pyarrow # check if we still do it on our side @@ -104,6 +117,7 @@ def test_pyarrow_dataset_source(df: pl.DataFrame, tmp_path: Path) -> None: "bools", "floats", "date" ), n_expected=1, + check_predicate_pushdown=True, ) helper_dataset_test( file_path, @@ -111,6 +125,7 @@ def test_pyarrow_dataset_source(df: pl.DataFrame, tmp_path: Path) -> None: pl.col("datetime") > datetime(1970, 1, 1, second=13) ).select("bools", "floats", "date"), n_expected=1, + check_predicate_pushdown=True, ) # not yet supported in pyarrow helper_dataset_test( @@ -119,6 +134,7 @@ def test_pyarrow_dataset_source(df: pl.DataFrame, tmp_path: Path) -> None: "bools", "time", "date" ), n_expected=3, + check_predicate_pushdown=True, ) # pushdown is_in helper_dataset_test( @@ -127,6 +143,7 @@ def test_pyarrow_dataset_source(df: pl.DataFrame, tmp_path: Path) -> None: "bools", "floats", "date" ), n_expected=2, + check_predicate_pushdown=True, ) helper_dataset_test( file_path, @@ -134,6 +151,7 @@ def test_pyarrow_dataset_source(df: pl.DataFrame, tmp_path: Path) -> None: pl.col("date").is_in([date(1973, 8, 17), date(1973, 5, 19)]) ).select("bools", "floats", "date"), n_expected=2, + check_predicate_pushdown=True, ) helper_dataset_test( file_path, @@ -146,6 +164,7 @@ def test_pyarrow_dataset_source(df: pl.DataFrame, tmp_path: Path) -> None: ) ).select("bools", "floats", "date"), n_expected=2, + check_predicate_pushdown=True, ) helper_dataset_test( file_path, @@ -153,6 +172,7 @@ def test_pyarrow_dataset_source(df: pl.DataFrame, tmp_path: Path) -> None: "bools", "floats", "date" ), n_expected=3, + check_predicate_pushdown=True, ) # TODO: remove string cache with pl.StringCache(): @@ -179,6 +199,15 @@ def 
test_pyarrow_dataset_source(df: pl.DataFrame, tmp_path: Path) -> None: n_expected=2, ) + helper_dataset_test( + file_path, + lambda lf: lf.filter(pl.col("bools") & pl.col("int").is_in([1, 2])).select( + "bools", "floats" + ), + n_expected=1, + check_predicate_pushdown=True, + ) + def test_pyarrow_dataset_comm_subplan_elim(tmp_path: Path) -> None: df0 = pl.DataFrame({"a": [1, 2, 3]}) diff --git a/py-polars/tests/unit/io/test_scan.py b/py-polars/tests/unit/io/test_scan.py index a1094ec778f0..d710954fcbfe 100644 --- a/py-polars/tests/unit/io/test_scan.py +++ b/py-polars/tests/unit/io/test_scan.py @@ -1,5 +1,6 @@ from __future__ import annotations +import io from dataclasses import dataclass from functools import partial from math import ceil @@ -190,7 +191,7 @@ def data_file( raise NotImplementedError() -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_scan( capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool ) -> None: @@ -205,7 +206,7 @@ def test_scan( assert_frame_equal(df, data_file.df) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_scan_with_limit( capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool ) -> None: @@ -227,7 +228,7 @@ def test_scan_with_limit( ) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_scan_with_filter( capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool ) -> None: @@ -253,7 +254,7 @@ def test_scan_with_filter( ) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_scan_with_filter_and_limit( capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool ) -> None: @@ -280,7 +281,7 @@ def test_scan_with_filter_and_limit( ) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_scan_with_limit_and_filter( capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool ) -> None: @@ -307,7 +308,7 @@ def test_scan_with_limit_and_filter( ) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_scan_with_row_index_and_limit( capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool ) -> None: @@ -335,7 +336,7 @@ def test_scan_with_row_index_and_limit( ) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_scan_with_row_index_and_filter( capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool ) -> None: @@ -363,7 +364,7 @@ def test_scan_with_row_index_and_filter( ) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_scan_with_row_index_limit_and_filter( capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool ) -> None: @@ -392,7 +393,7 @@ def test_scan_with_row_index_limit_and_filter( ) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_scan_with_row_index_projected_out( capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool ) -> None: @@ -415,7 +416,7 @@ def test_scan_with_row_index_projected_out( assert_frame_equal(df, data_file.df.select(subset)) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_scan_with_row_index_filter_and_limit( capfd: Any, monkeypatch: pytest.MonkeyPatch, data_file: _DataFile, force_async: bool ) -> None: @@ -447,7 +448,7 @@ def test_scan_with_row_index_filter_and_limit( ) -@pytest.mark.write_disk() +@pytest.mark.write_disk @pytest.mark.parametrize( ("scan_func", "write_func"), [ @@ -474,7 +475,7 @@ def test_scan_limit_0_does_not_panic( 
assert_frame_equal(scan_func(path).head(0).collect(streaming=streaming), df.clear()) -@pytest.mark.write_disk() +@pytest.mark.write_disk @pytest.mark.parametrize( ("scan_func", "write_func"), [ @@ -526,7 +527,7 @@ def test_scan_directory( assert_frame_equal(out, df) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_scan_glob_excludes_directories(tmp_path: Path) -> None: for dir in ["dir1", "dir2", "dir3"]: (tmp_path / dir).mkdir() @@ -546,7 +547,7 @@ def test_scan_glob_excludes_directories(tmp_path: Path) -> None: @pytest.mark.parametrize("file_name", ["a b", "a %25 b"]) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_scan_async_whitespace_in_path( tmp_path: Path, monkeypatch: Any, file_name: str ) -> None: @@ -563,7 +564,7 @@ def test_scan_async_whitespace_in_path( path.unlink() -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_path_expansion_excludes_empty_files_17362(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -575,7 +576,18 @@ def test_path_expansion_excludes_empty_files_17362(tmp_path: Path) -> None: assert_frame_equal(pl.scan_parquet(tmp_path / "*").collect(), df) -@pytest.mark.write_disk() +@pytest.mark.write_disk +def test_path_expansion_empty_directory_does_not_panic(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + with pytest.raises(pl.exceptions.ComputeError): + pl.scan_parquet(tmp_path).collect() + + with pytest.raises(pl.exceptions.ComputeError): + pl.scan_parquet(tmp_path / "**/*").collect() + + +@pytest.mark.write_disk def test_scan_single_dir_differing_file_extensions_raises_17436(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -619,7 +631,7 @@ def test_scan_nonexistent_path(format: str) -> None: result.collect() -@pytest.mark.write_disk() +@pytest.mark.write_disk @pytest.mark.parametrize( ("scan_func", "write_func"), [ @@ -640,18 +652,14 @@ def test_scan_include_file_name( streaming: bool, ) -> None: tmp_path.mkdir(exist_ok=True) - paths: list[Path] = [] dfs: list[pl.DataFrame] = [] for x in ["1", "2"]: - paths.append(Path(f"{tmp_path}/{x}.bin").absolute()) - dfs.append(pl.DataFrame({"x": x})) - write_func(dfs[-1], paths[-1]) - - df = pl.concat(dfs).with_columns( - pl.Series("path", map(str, paths), dtype=pl.String) - ) + path = Path(f"{tmp_path}/{x}.bin").absolute() + dfs.append(pl.DataFrame({"x": 10 * [x]}).with_columns(path=pl.lit(str(path)))) + write_func(dfs[-1].drop("path"), path) + df = pl.concat(dfs) assert df.columns == ["x", "path"] with pytest.raises( @@ -686,7 +694,7 @@ def test_scan_include_file_name( assert_frame_equal(lf.head(0).collect(streaming=streaming), df.head(0)) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_async_path_expansion_bracket_17629(tmp_path: Path) -> None: path = tmp_path / "data.parquet" @@ -694,3 +702,86 @@ def test_async_path_expansion_bracket_17629(tmp_path: Path) -> None: df.write_parquet(path) assert_frame_equal(pl.scan_parquet(tmp_path / "[d]ata.parquet").collect(), df) + + +@pytest.mark.parametrize( + "method", + ["parquet", "csv", "ipc", "ndjson"], +) +def test_scan_in_memory(method: str) -> None: + f = io.BytesIO() + df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": ["x", "y", "z"], + } + ) + + (getattr(df, f"write_{method}"))(f) + + f.seek(0) + result = (getattr(pl, f"scan_{method}"))(f).collect() + assert_frame_equal(df, result) + + f.seek(0) + result = (getattr(pl, f"scan_{method}"))(f).slice(1, 2).collect() + assert_frame_equal(df.slice(1, 2), result) + + f.seek(0) + result = (getattr(pl, f"scan_{method}"))(f).slice(-1, 1).collect() + 
assert_frame_equal(df.slice(-1, 1), result) + + g = io.BytesIO() + (getattr(df, f"write_{method}"))(g) + + f.seek(0) + g.seek(0) + result = (getattr(pl, f"scan_{method}"))([f, g]).collect() + assert_frame_equal(df.vstack(df), result) + + f.seek(0) + g.seek(0) + result = (getattr(pl, f"scan_{method}"))([f, g]).slice(1, 2).collect() + assert_frame_equal(df.vstack(df).slice(1, 2), result) + + f.seek(0) + g.seek(0) + result = (getattr(pl, f"scan_{method}"))([f, g]).slice(-1, 1).collect() + assert_frame_equal(df.vstack(df).slice(-1, 1), result) + + +@pytest.mark.parametrize( + "method", + ["csv", "ndjson"], +) +def test_scan_stringio(method: str) -> None: + f = io.StringIO() + df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": ["x", "y", "z"], + } + ) + + (getattr(df, f"write_{method}"))(f) + + f.seek(0) + result = (getattr(pl, f"scan_{method}"))(f).collect() + assert_frame_equal(df, result) + + g = io.StringIO() + (getattr(df, f"write_{method}"))(g) + + f.seek(0) + g.seek(0) + result = (getattr(pl, f"scan_{method}"))([f, g]).collect() + assert_frame_equal(df.vstack(df), result) + + +@pytest.mark.parametrize( + "method", + [pl.scan_parquet, pl.scan_csv, pl.scan_ipc, pl.scan_ndjson], +) +def test_empty_list(method: Callable[[list[str]], pl.LazyFrame]) -> None: + with pytest.raises(pl.exceptions.ComputeError, match="expected at least 1 source"): + _ = (method)([]).collect() diff --git a/py-polars/tests/unit/io/test_spreadsheet.py b/py-polars/tests/unit/io/test_spreadsheet.py index 57354a4a336b..06af29659ba0 100644 --- a/py-polars/tests/unit/io/test_spreadsheet.py +++ b/py-polars/tests/unit/io/test_spreadsheet.py @@ -22,61 +22,61 @@ pytestmark = pytest.mark.slow() -@pytest.fixture() +@pytest.fixture def path_xls(io_files_path: Path) -> Path: # old excel 97-2004 format return io_files_path / "example.xls" -@pytest.fixture() +@pytest.fixture def path_xlsx(io_files_path: Path) -> Path: # modern excel format return io_files_path / "example.xlsx" -@pytest.fixture() +@pytest.fixture def path_xlsb(io_files_path: Path) -> Path: # excel binary format return io_files_path / "example.xlsb" -@pytest.fixture() +@pytest.fixture def path_ods(io_files_path: Path) -> Path: # open document spreadsheet return io_files_path / "example.ods" -@pytest.fixture() +@pytest.fixture def path_xls_empty(io_files_path: Path) -> Path: return io_files_path / "empty.xls" -@pytest.fixture() +@pytest.fixture def path_xlsx_empty(io_files_path: Path) -> Path: return io_files_path / "empty.xlsx" -@pytest.fixture() +@pytest.fixture def path_xlsx_mixed(io_files_path: Path) -> Path: return io_files_path / "mixed.xlsx" -@pytest.fixture() +@pytest.fixture def path_xlsb_empty(io_files_path: Path) -> Path: return io_files_path / "empty.xlsb" -@pytest.fixture() +@pytest.fixture def path_xlsb_mixed(io_files_path: Path) -> Path: return io_files_path / "mixed.xlsb" -@pytest.fixture() +@pytest.fixture def path_ods_empty(io_files_path: Path) -> Path: return io_files_path / "empty.ods" -@pytest.fixture() +@pytest.fixture def path_ods_mixed(io_files_path: Path) -> Path: return io_files_path / "mixed.ods" @@ -644,20 +644,22 @@ def test_excel_round_trip(write_params: dict[str, Any]) -> None: engine: ExcelSpreadsheetEngine for engine in ("calamine", "xlsx2csv"): - read_options = ( - {} + read_options, has_header = ( + ({}, True) if write_params.get("include_header", True) else ( - {"has_header": False, "new_columns": ["dtm", "str", "val"]} + {"new_columns": ["dtm", "str", "val"]} if engine == "xlsx2csv" - else {"header_row": None, "column_names": ["dtm", 
"str", "val"]} + else {"column_names": ["dtm", "str", "val"]}, + False, ) ) + fmt_strptime = "%Y-%m-%d" if write_params.get("dtype_formats", {}).get(pl.Date) == "dd-mm-yyyy": fmt_strptime = "%d-%m-%Y" - # write to an xlsx with polars, using various parameters... + # write to xlsx using various parameters... xls = BytesIO() _wb = df.write_excel(workbook=xls, worksheet="data", **write_params) @@ -667,6 +669,7 @@ def test_excel_round_trip(write_params: dict[str, Any]) -> None: sheet_name="data", engine=engine, read_options=read_options, + has_header=has_header, )[:3].select(df.columns[:3]) if engine == "xlsx2csv": @@ -727,6 +730,19 @@ def test_excel_write_compound_types(engine: ExcelSpreadsheetEngine) -> None: ] +@pytest.mark.parametrize("engine", ["xlsx2csv", "openpyxl", "calamine"]) +def test_excel_read_no_headers(engine: ExcelSpreadsheetEngine) -> None: + df = pl.DataFrame( + {"colx": [1, 2, 3], "coly": ["aaa", "bbb", "ccc"], "colz": [0.5, 0.0, -1.0]} + ) + xls = BytesIO() + df.write_excel(xls, worksheet="data", include_header=False) + + xldf = pl.read_excel(xls, engine=engine, has_header=False) + expected = xldf.rename({"column_1": "colx", "column_2": "coly", "column_3": "colz"}) + assert_frame_equal(df, expected) + + @pytest.mark.parametrize("engine", ["xlsx2csv", "openpyxl", "calamine"]) def test_excel_write_sparklines(engine: ExcelSpreadsheetEngine) -> None: from xlsxwriter import Workbook diff --git a/py-polars/tests/unit/lazyframe/optimizations.py b/py-polars/tests/unit/lazyframe/optimizations.py index 1069fc3699be..2417edecdeb8 100644 --- a/py-polars/tests/unit/lazyframe/optimizations.py +++ b/py-polars/tests/unit/lazyframe/optimizations.py @@ -1,4 +1,7 @@ +import io + import polars as pl +from polars.testing import assert_frame_equal def test_remove_double_sort() -> None: @@ -6,3 +9,34 @@ def test_remove_double_sort() -> None: pl.LazyFrame({"a": [1, 2, 3, 3]}).sort("a").sort("a").explain().count("SORT") == 1 ) + + +def test_double_sort_maintain_order_18558() -> None: + df = pl.DataFrame( + { + "col1": [1, 2, 2, 4, 5, 6], + "col2": [2, 2, 0, 0, 2, None], + } + ) + + lf = df.lazy().sort("col2").sort("col1", maintain_order=True) + + expect = pl.DataFrame( + [ + pl.Series("col1", [1, 2, 2, 4, 5, 6], dtype=pl.Int64), + pl.Series("col2", [2, 0, 2, 0, 2, None], dtype=pl.Int64), + ] + ) + + assert_frame_equal(lf.collect(), expect) + + +def test_fast_count_alias_18581() -> None: + f = io.BytesIO() + f.write(b"a,b,c\n1,2,3\n4,5,6") + f.flush() + f.seek(0) + + df = pl.scan_csv(f).select(pl.len().alias("weird_name")).collect() + + assert_frame_equal(pl.DataFrame({"weird_name": 2}), df) diff --git a/py-polars/tests/unit/lazyframe/test_engine_selection.py b/py-polars/tests/unit/lazyframe/test_engine_selection.py index 760f63f6baa8..cb4156989b69 100644 --- a/py-polars/tests/unit/lazyframe/test_engine_selection.py +++ b/py-polars/tests/unit/lazyframe/test_engine_selection.py @@ -11,7 +11,7 @@ from polars._typing import EngineType -@pytest.fixture() +@pytest.fixture def df() -> pl.LazyFrame: return pl.LazyFrame({"a": [1, 2, 3]}) diff --git a/py-polars/tests/unit/lazyframe/test_lazyframe.py b/py-polars/tests/unit/lazyframe/test_lazyframe.py index 730a51d0b493..23394110b40e 100644 --- a/py-polars/tests/unit/lazyframe/test_lazyframe.py +++ b/py-polars/tests/unit/lazyframe/test_lazyframe.py @@ -1309,7 +1309,7 @@ def test_compare_schema_between_lazy_and_eager_6904() -> None: assert eager_result.shape == lazy_result.shape -@pytest.mark.slow() +@pytest.mark.slow @pytest.mark.parametrize( "dtype", [ diff 
--git a/py-polars/tests/unit/lazyframe/test_optimizations.py b/py-polars/tests/unit/lazyframe/test_optimizations.py new file mode 100644 index 000000000000..648bd3123787 --- /dev/null +++ b/py-polars/tests/unit/lazyframe/test_optimizations.py @@ -0,0 +1,206 @@ +import polars as pl +from polars.testing import assert_frame_equal + + +def test_is_null_followed_by_all() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1], "val": [6, 0, None, None]}) + + expected_df = pl.DataFrame({"group": [0, 1], "val": [False, True]}) + result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").is_null().all() + ) + + assert ( + r'[[(col("val").count()) == (col("val").null_count())]]' in result_lf.explain() + ) + assert "is_null" not in result_lf + assert_frame_equal(expected_df, result_lf.collect()) + + # verify we don't optimize on chained expressions when last one is not col + non_optimized_result_plan = ( + lf.group_by("group", maintain_order=True) + .agg(pl.col("val").abs().is_null().all()) + .explain() + ) + assert "null_count" not in non_optimized_result_plan + assert "is_null" in non_optimized_result_plan + + # edge case of empty series + lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32}) + + expected_df = pl.DataFrame({"val": [True]}) + result_df = lf.select(pl.col("val").is_null().all()).collect() + assert_frame_equal(expected_df, result_df) + + +def test_is_null_followed_by_any() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]}) + + expected_df = pl.DataFrame({"group": [0, 1, 2], "val": [True, True, False]}) + result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").is_null().any() + ) + assert_frame_equal(expected_df, result_lf.collect()) + + # edge case of empty series + lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32}) + + expected_df = pl.DataFrame({"val": [False]}) + result_df = lf.select(pl.col("val").is_null().any()).collect() + assert_frame_equal(expected_df, result_df) + + +def test_is_not_null_followed_by_all() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1], "val": [6, 0, 5, None]}) + + expected_df = pl.DataFrame({"group": [0, 1], "val": [True, False]}) + result_df = ( + lf.group_by("group", maintain_order=True) + .agg(pl.col("val").is_not_null().all()) + .collect() + ) + + assert_frame_equal(expected_df, result_df) + + # edge case of empty series + lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32}) + + expected_df = pl.DataFrame({"val": [True]}) + result_df = lf.select(pl.col("val").is_not_null().all()).collect() + assert_frame_equal(expected_df, result_df) + + +def test_is_not_null_followed_by_any() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]}) + + expected_df = pl.DataFrame({"group": [0, 1, 2], "val": [True, False, True]}) + result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").is_not_null().any() + ) + + assert ( + r'[[(col("val").null_count()) < (col("val").count())]]' in result_lf.explain() + ) + assert "is_not_null" not in result_lf.explain() + assert_frame_equal(expected_df, result_lf.collect()) + + # verify we don't optimize on chained expressions when last one is not col + non_optimized_result_plan = ( + lf.group_by("group", maintain_order=True) + .agg(pl.col("val").abs().is_not_null().any()) + .explain() + ) + assert "null_count" not in non_optimized_result_plan + assert "is_not_null" in non_optimized_result_plan + + # edge case of empty series + lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32}) + + expected_df = 
pl.DataFrame({"val": [False]}) + result_df = lf.select(pl.col("val").is_not_null().any()).collect() + assert_frame_equal(expected_df, result_df) + + +def test_is_null_followed_by_sum() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]}) + + expected_df = pl.DataFrame( + {"group": [0, 1, 2], "val": [1, 1, 0]}, schema_overrides={"val": pl.UInt32} + ) + result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").is_null().sum() + ) + + assert r'[col("val").null_count()]' in result_lf.explain() + assert "is_null" not in result_lf.explain() + assert_frame_equal(expected_df, result_lf.collect()) + + # edge case of empty series + lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32}) + + expected_df = pl.DataFrame({"val": [0]}, schema={"val": pl.UInt32}) + result_df = lf.select(pl.col("val").is_null().sum()).collect() + assert_frame_equal(expected_df, result_df) + + +def test_is_not_null_followed_by_sum() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]}) + + expected_df = pl.DataFrame( + {"group": [0, 1, 2], "val": [2, 0, 1]}, schema_overrides={"val": pl.UInt32} + ) + result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").is_not_null().sum() + ) + + assert ( + r'[[(col("val").count()) - (col("val").null_count())]]' in result_lf.explain() + ) + assert "is_not_null" not in result_lf.explain() + assert_frame_equal(expected_df, result_lf.collect()) + + # verify we don't optimize on chained expressions when last one is not col + non_optimized_result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").abs().is_not_null().sum() + ) + assert "null_count" not in non_optimized_result_lf.explain() + assert "is_not_null" in non_optimized_result_lf.explain() + + # edge case of empty series + lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32}) + + expected_df = pl.DataFrame({"val": [0]}, schema={"val": pl.UInt32}) + result_df = lf.select(pl.col("val").is_not_null().sum()).collect() + assert_frame_equal(expected_df, result_df) + + +def test_drop_nulls_followed_by_len() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]}) + + expected_df = pl.DataFrame( + {"group": [0, 1, 2], "val": [2, 0, 1]}, schema_overrides={"val": pl.UInt32} + ) + result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").drop_nulls().len() + ) + + assert ( + r'[[(col("val").count()) - (col("val").null_count())]]' in result_lf.explain() + ) + assert "drop_nulls" not in result_lf.explain() + assert_frame_equal(expected_df, result_lf.collect()) + + # verify we don't optimize on chained expressions when last one is not col + non_optimized_result_plan = ( + lf.group_by("group", maintain_order=True) + .agg(pl.col("val").abs().drop_nulls().len()) + .explain() + ) + assert "null_count" not in non_optimized_result_plan + assert "drop_nulls" in non_optimized_result_plan + + +def test_drop_nulls_followed_by_count() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]}) + + expected_df = pl.DataFrame( + {"group": [0, 1, 2], "val": [2, 0, 1]}, schema_overrides={"val": pl.UInt32} + ) + result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").drop_nulls().count() + ) + + assert ( + r'[[(col("val").count()) - (col("val").null_count())]]' in result_lf.explain() + ) + assert "drop_nulls" not in result_lf.explain() + assert_frame_equal(expected_df, result_lf.collect()) + + # verify we don't optimize on chained expressions when last 
one is not col + non_optimized_result_plan = ( + lf.group_by("group", maintain_order=True) + .agg(pl.col("val").abs().drop_nulls().count()) + .explain() + ) + assert "null_count" not in non_optimized_result_plan + assert "drop_nulls" in non_optimized_result_plan diff --git a/py-polars/tests/unit/lazyframe/test_serde.py b/py-polars/tests/unit/lazyframe/test_serde.py index 515ce490693e..a82e389b4583 100644 --- a/py-polars/tests/unit/lazyframe/test_serde.py +++ b/py-polars/tests/unit/lazyframe/test_serde.py @@ -40,25 +40,20 @@ def test_lf_serde_roundtrip_binary(lf: pl.LazyFrame) -> None: ], ) ) +@pytest.mark.filterwarnings("ignore") def test_lf_serde_roundtrip_json(lf: pl.LazyFrame) -> None: serialized = lf.serialize(format="json") result = pl.LazyFrame.deserialize(io.StringIO(serialized), format="json") assert_frame_equal(result, lf, categorical_as_str=True) -@pytest.fixture() +@pytest.fixture def lf() -> pl.LazyFrame: """Sample LazyFrame for testing serialization/deserialization.""" return pl.LazyFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}).select("a").sum() -def test_lf_serde(lf: pl.LazyFrame) -> None: - serialized = lf.serialize() - assert isinstance(serialized, bytes) - result = pl.LazyFrame.deserialize(io.BytesIO(serialized)) - assert_frame_equal(result, lf) - - +@pytest.mark.filterwarnings("ignore") def test_lf_serde_json_stringio(lf: pl.LazyFrame) -> None: serialized = lf.serialize(format="json") assert isinstance(serialized, str) @@ -66,6 +61,13 @@ def test_lf_serde_json_stringio(lf: pl.LazyFrame) -> None: assert_frame_equal(result, lf) +def test_lf_serde(lf: pl.LazyFrame) -> None: + serialized = lf.serialize() + assert isinstance(serialized, bytes) + result = pl.LazyFrame.deserialize(io.BytesIO(serialized)) + assert_frame_equal(result, lf) + + @pytest.mark.parametrize( ("format", "buf"), [ @@ -74,6 +76,7 @@ def test_lf_serde_json_stringio(lf: pl.LazyFrame) -> None: ("json", io.BytesIO()), ], ) +@pytest.mark.filterwarnings("ignore") def test_lf_serde_to_from_buffer( lf: pl.LazyFrame, format: SerializationFormat, buf: io.IOBase ) -> None: @@ -83,7 +86,7 @@ def test_lf_serde_to_from_buffer( assert_frame_equal(lf, result) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_lf_serde_to_from_file(lf: pl.LazyFrame, tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -100,7 +103,7 @@ def test_lf_deserialize_validation() -> None: pl.LazyFrame.deserialize(f, format="json") -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_lf_serde_scan(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) path = tmp_path / "dataset.parquet" diff --git a/py-polars/tests/unit/meta/test_versions.py b/py-polars/tests/unit/meta/test_versions.py index 89e0c301807f..944f921ac1fb 100644 --- a/py-polars/tests/unit/meta/test_versions.py +++ b/py-polars/tests/unit/meta/test_versions.py @@ -5,7 +5,7 @@ import polars as pl -@pytest.mark.slow() +@pytest.mark.slow def test_show_versions(capsys: Any) -> None: pl.show_versions() diff --git a/py-polars/tests/unit/ml/test_to_jax.py b/py-polars/tests/unit/ml/test_to_jax.py index e18902a11348..0dbc4effca18 100644 --- a/py-polars/tests/unit/ml/test_to_jax.py +++ b/py-polars/tests/unit/ml/test_to_jax.py @@ -20,7 +20,7 @@ from polars._typing import PolarsDataType -@pytest.fixture() +@pytest.fixture def df() -> pl.DataFrame: return pl.DataFrame( { diff --git a/py-polars/tests/unit/ml/test_to_torch.py b/py-polars/tests/unit/ml/test_to_torch.py index c42c6f2ea666..2dc8f4141025 100644 --- a/py-polars/tests/unit/ml/test_to_torch.py +++ 
b/py-polars/tests/unit/ml/test_to_torch.py @@ -15,7 +15,7 @@ pytestmark = pytest.mark.ci_only -@pytest.fixture() +@pytest.fixture def df() -> pl.DataFrame: return pl.DataFrame( { diff --git a/py-polars/tests/unit/operations/aggregation/test_aggregations.py b/py-polars/tests/unit/operations/aggregation/test_aggregations.py index ff6b27280382..8667b419d59e 100644 --- a/py-polars/tests/unit/operations/aggregation/test_aggregations.py +++ b/py-polars/tests/unit/operations/aggregation/test_aggregations.py @@ -113,7 +113,7 @@ def test_quantile() -> None: assert s.quantile(0.5, "higher") == 2 -@pytest.mark.slow() +@pytest.mark.slow @pytest.mark.parametrize("tp", [int, float]) @pytest.mark.parametrize("n", [1, 2, 10, 100]) def test_quantile_vs_numpy(tp: type, n: int) -> None: @@ -437,7 +437,7 @@ def test_agg_filter_over_empty_df_13610() -> None: assert_frame_equal(out, expected) -@pytest.mark.slow() +@pytest.mark.slow def test_agg_empty_sum_after_filter_14734() -> None: f = ( pl.DataFrame({"a": [1, 2], "b": [1, 2]}) @@ -462,7 +462,7 @@ def test_agg_empty_sum_after_filter_14734() -> None: assert_frame_equal(expect, curr.select("b")) -@pytest.mark.slow() +@pytest.mark.slow def test_grouping_hash_14749() -> None: n_groups = 251 rows_per_group = 4 @@ -543,7 +543,7 @@ def test_group_count_over_null_column_15705() -> None: assert out["c"].to_list() == [0, 0, 0] -@pytest.mark.release() +@pytest.mark.release def test_min_max_2850() -> None: # https://github.com/pola-rs/polars/issues/2850 df = pl.DataFrame( diff --git a/py-polars/tests/unit/operations/aggregation/test_horizontal.py b/py-polars/tests/unit/operations/aggregation/test_horizontal.py index bc2bf2abbf46..b4054c695979 100644 --- a/py-polars/tests/unit/operations/aggregation/test_horizontal.py +++ b/py-polars/tests/unit/operations/aggregation/test_horizontal.py @@ -7,6 +7,7 @@ import pytest import polars as pl +import polars.selectors as cs from polars.exceptions import ComputeError from polars.testing import assert_frame_equal, assert_series_equal @@ -51,6 +52,20 @@ def test_all_any_horizontally() -> None: assert "horizontal" not in dfltr.explain().lower() +def test_empty_all_any_horizontally() -> None: + # any/all_horizontal don't allow empty input, but we can still trigger this + # by selecting an empty set of columns with pl.selectors. 
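The fallback described in the comment above matches the usual reduction identities, an OR over no inputs is False and an AND over no inputs is True, which is also how Python's built-in any()/all() behave on empty input. A minimal illustration of that identity, offered only as a mental model for the literal values asserted below:

    # Reduction identities assumed by the empty-selection fallback:
    # any over an empty set -> False, all over an empty set -> True.
    assert any([]) is False
    assert all([]) is True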
+ df = pl.DataFrame({"x": [1, 2, 3]}) + assert_frame_equal( + df.select(pl.any_horizontal(cs.string().is_null())), + pl.DataFrame({"literal": False}), + ) + assert_frame_equal( + df.select(pl.all_horizontal(cs.string().is_null())), + pl.DataFrame({"literal": True}), + ) + + def test_all_any_single_input() -> None: df = pl.DataFrame({"a": [0, 1, None]}) out = df.select( diff --git a/py-polars/tests/unit/operations/aggregation/test_vertical.py b/py-polars/tests/unit/operations/aggregation/test_vertical.py index 3f2dbe080c07..fc74fdf59b65 100644 --- a/py-polars/tests/unit/operations/aggregation/test_vertical.py +++ b/py-polars/tests/unit/operations/aggregation/test_vertical.py @@ -57,7 +57,7 @@ def test_alias_for_col_agg(function: str, input: str) -> None: assert_expr_equal(result, expected, context) -@pytest.mark.release() +@pytest.mark.release def test_mean_overflow() -> None: np.random.seed(1) expected = 769.5607652 diff --git a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py index 4c6877d08694..57a8ce795dc1 100644 --- a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py +++ b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py @@ -343,10 +343,6 @@ def test_parse_apply_raw_functions() -> None: ): df1 = lf.select(pl.col("a").map_elements(func)).collect() df2 = lf.select(getattr(pl.col("a"), func_name)()).collect() - if func_name == "sign": - # note: Polars' 'sign' function returns an Int64, while numpy's - # 'sign' function returns a Float64 - df1 = df1.with_columns(pl.col("a").cast(pl.Int64)) assert_frame_equal(df1, df2) # test bare 'json.loads' diff --git a/py-polars/tests/unit/operations/map/test_map_elements.py b/py-polars/tests/unit/operations/map/test_map_elements.py index 7edc155e223f..ce147be9ef27 100644 --- a/py-polars/tests/unit/operations/map/test_map_elements.py +++ b/py-polars/tests/unit/operations/map/test_map_elements.py @@ -10,6 +10,10 @@ from polars.exceptions import PolarsInefficientMapWarning from polars.testing import assert_frame_equal, assert_series_equal +pytestmark = pytest.mark.filterwarnings( + "ignore::polars.exceptions.PolarsInefficientMapWarning" +) + def test_map_elements_infer_list() -> None: df = pl.DataFrame( @@ -180,17 +184,13 @@ def test_map_elements_skip_nulls() -> None: some_map = {None: "a", 1: "b"} s = pl.Series([None, 1]) - with pytest.warns( - PolarsInefficientMapWarning, - match=r"(?s)Replace this expression.*s\.map_elements\(lambda x:", - ): - assert s.map_elements( - lambda x: some_map[x], return_dtype=pl.String - ).to_list() == [None, "b"] + assert s.map_elements( + lambda x: some_map[x], return_dtype=pl.String, skip_nulls=True + ).to_list() == [None, "b"] - assert s.map_elements( - lambda x: some_map[x], return_dtype=pl.String, skip_nulls=False - ).to_list() == ["a", "b"] + assert s.map_elements( + lambda x: some_map[x], return_dtype=pl.String, skip_nulls=False + ).to_list() == ["a", "b"] def test_map_elements_object_dtypes() -> None: @@ -364,3 +364,22 @@ def test_unknown_map_elements() -> None: "Flour": [10.0, 100.0, 100.0, 20.0], } assert q.collect_schema().dtypes() == [pl.Int64, pl.Unknown] + + +def test_map_elements_list_dtype_18472() -> None: + s = pl.Series([[None], ["abc ", None]]) + result = s.map_elements(lambda s: [i.strip() if i else None for i in s]) + expected = pl.Series([[None], ["abc", None]]) + assert_series_equal(result, expected) + + +def test_map_elements_list_return_dtype() -> None: + s = pl.Series([[1], 
[2, 3]]) + return_dtype = pl.List(pl.UInt16) + + result = s.map_elements( + lambda s: [i + 1 for i in s], + return_dtype=return_dtype, + ) + expected = pl.Series([[2], [3, 4]], dtype=return_dtype) + assert_series_equal(result, expected) diff --git a/py-polars/tests/unit/operations/namespaces/conftest.py b/py-polars/tests/unit/operations/namespaces/conftest.py index 9aa3a121a35d..6c017aa02e6e 100644 --- a/py-polars/tests/unit/operations/namespaces/conftest.py +++ b/py-polars/tests/unit/operations/namespaces/conftest.py @@ -5,6 +5,6 @@ import pytest -@pytest.fixture() +@pytest.fixture def namespace_files_path() -> Path: return Path(__file__).parent / "files" diff --git a/py-polars/tests/unit/operations/namespaces/list/test_eval.py b/py-polars/tests/unit/operations/namespaces/list/test_eval.py new file mode 100644 index 000000000000..7432fe413384 --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/list/test_eval.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import polars as pl +from polars.exceptions import ( + StructFieldNotFoundError, +) +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_list_eval_dtype_inference() -> None: + grades = pl.DataFrame( + { + "student": ["bas", "laura", "tim", "jenny"], + "arithmetic": [10, 5, 6, 8], + "biology": [4, 6, 2, 7], + "geography": [8, 4, 9, 7], + } + ) + + rank_pct = pl.col("").rank(descending=True) / pl.col("").count().cast(pl.UInt16) + + # the .list.first() would fail if .list.eval did not correctly infer the output type + assert grades.with_columns( + pl.concat_list(pl.all().exclude("student")).alias("all_grades") + ).select( + pl.col("all_grades") + .list.eval(rank_pct, parallel=True) + .alias("grades_rank") + .list.first() + ).to_series().to_list() == [ + 0.3333333333333333, + 0.6666666666666666, + 0.6666666666666666, + 0.3333333333333333, + ] + + +def test_list_eval_categorical() -> None: + df = pl.DataFrame({"test": [["a", None]]}, schema={"test": pl.List(pl.Categorical)}) + df = df.select( + pl.col("test").list.eval(pl.element().filter(pl.element().is_not_null())) + ) + assert_series_equal( + df.get_column("test"), pl.Series("test", [["a"]], dtype=pl.List(pl.Categorical)) + ) + + +def test_list_eval_type_coercion() -> None: + last_non_null_value = pl.element().fill_null(3).last() + df = pl.DataFrame({"array_cols": [[1, None]]}) + + assert df.select( + pl.col("array_cols") + .list.eval(last_non_null_value, parallel=False) + .alias("col_last") + ).to_dict(as_series=False) == {"col_last": [[3]]} + + +def test_list_eval_all_null() -> None: + df = pl.DataFrame({"foo": [1, 2, 3], "bar": [None, None, None]}).with_columns( + pl.col("bar").cast(pl.List(pl.String)) + ) + + assert df.select(pl.col("bar").list.eval(pl.element())).to_dict( + as_series=False + ) == {"bar": [None, None, None]} + + +def test_empty_eval_dtype_5546() -> None: + # https://github.com/pola-rs/polars/issues/5546 + df = pl.DataFrame([{"a": [{"name": 1}, {"name": 2}]}]) + + dtype = df.dtypes[0] + + assert ( + df.limit(0).with_columns( + pl.col("a") + .list.eval(pl.element().filter(pl.first().struct.field("name") == 1)) + .alias("a_filtered") + ) + ).dtypes == [dtype, dtype] + + +def test_list_eval_gather_every_13410() -> None: + df = pl.DataFrame({"a": [[1, 2, 3], [4, 5, 6]]}) + out = df.with_columns(result=pl.col("a").list.eval(pl.element().gather_every(2))) + expected = pl.DataFrame({"a": [[1, 2, 3], [4, 5, 6]], "result": [[1, 3], [4, 6]]}) + assert_frame_equal(out, expected) + + +def 
test_list_eval_err_raise_15653() -> None: + df = pl.DataFrame({"foo": [[]]}) + with pytest.raises(StructFieldNotFoundError): + df.with_columns(bar=pl.col("foo").list.eval(pl.element().struct.field("baz"))) + + +def test_list_eval_type_cast_11188() -> None: + df = pl.DataFrame( + [ + {"a": None}, + ], + schema={"a": pl.List(pl.Int64)}, + ) + assert df.select( + pl.col("a").list.eval(pl.element().cast(pl.String)).alias("a_str") + ).schema == {"a_str": pl.List(pl.String)} + + +@pytest.mark.parametrize( + "data", + [ + {"a": [["0"], ["1"]]}, + {"a": [["0", "1"], ["2", "3"]]}, + {"a": [["0", "1"]]}, + {"a": [["0"]]}, + ], +) +@pytest.mark.parametrize( + "expr", + [ + pl.lit(""), + pl.format("test: {}", pl.element()), + ], +) +def test_list_eval_list_output_18510(data: dict[str, Any], expr: pl.Expr) -> None: + df = pl.DataFrame(data) + result = df.select(pl.col("a").list.eval(pl.lit(""))) + assert result.to_series().dtype == pl.List(pl.String) diff --git a/py-polars/tests/unit/operations/namespaces/list/test_list.py b/py-polars/tests/unit/operations/namespaces/list/test_list.py index 4e9bea71b792..f306bbff5d7b 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_list.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_list.py @@ -1,17 +1,13 @@ from __future__ import annotations +import re from datetime import date, datetime import numpy as np import pytest import polars as pl -from polars.exceptions import ( - ComputeError, - OutOfBoundsError, - SchemaError, - StructFieldNotFoundError, -) +from polars.exceptions import ComputeError, OutOfBoundsError, SchemaError from polars.testing import assert_frame_equal, assert_series_equal @@ -164,6 +160,76 @@ def test_list_categorical_get() -> None: ) +def test_list_gather_wrong_indices_list_type() -> None: + a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) + expected = pl.Series("a", [[1, 2], [4], [6, 9]]) + + # int8 + indices_series = pl.Series("indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.Int8)) + result = a.list.gather(indices=indices_series) + assert_series_equal(result, expected) + + # int16 + indices_series = pl.Series( + "indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.Int16) + ) + result = a.list.gather(indices=indices_series) + assert_series_equal(result, expected) + + # int32 + indices_series = pl.Series( + "indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.Int32) + ) + result = a.list.gather(indices=indices_series) + assert_series_equal(result, expected) + + # int64 + indices_series = pl.Series( + "indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.Int64) + ) + result = a.list.gather(indices=indices_series) + assert_series_equal(result, expected) + + # uint8 + indices_series = pl.Series( + "indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.UInt8) + ) + result = a.list.gather(indices=indices_series) + assert_series_equal(result, expected) + + # uint16 + indices_series = pl.Series( + "indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.UInt16) + ) + result = a.list.gather(indices=indices_series) + assert_series_equal(result, expected) + + # uint32 + indices_series = pl.Series( + "indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.UInt32) + ) + result = a.list.gather(indices=indices_series) + assert_series_equal(result, expected) + + # uint64 + indices_series = pl.Series( + "indices", [[0, 1], [0], [0, 3]], dtype=pl.List(pl.UInt64) + ) + result = a.list.gather(indices=indices_series) + assert_series_equal(result, expected) + + df = pl.DataFrame( + { + "index": [["2"], ["2"], ["2"]], + "lists": [[3, 4, 5], [4, 5, 6], 
[7, 8, 9, 4]], + } + ) + with pytest.raises( + ComputeError, match=re.escape("cannot use dtype `list[str]` as an index") + ): + df.select(pl.col("lists").list.gather(pl.col("index"))) + + def test_contains() -> None: a = pl.Series("a", [[1, 2, 3], [2, 5], [6, 7, 8, 9]]) out = a.list.contains(2) @@ -342,44 +408,6 @@ def test_slice() -> None: assert s.list.slice(-5, 2).to_list() == [[1], []] -def test_list_eval_dtype_inference() -> None: - grades = pl.DataFrame( - { - "student": ["bas", "laura", "tim", "jenny"], - "arithmetic": [10, 5, 6, 8], - "biology": [4, 6, 2, 7], - "geography": [8, 4, 9, 7], - } - ) - - rank_pct = pl.col("").rank(descending=True) / pl.col("").count().cast(pl.UInt16) - - # the .list.first() would fail if .list.eval did not correctly infer the output type - assert grades.with_columns( - pl.concat_list(pl.all().exclude("student")).alias("all_grades") - ).select( - pl.col("all_grades") - .list.eval(rank_pct, parallel=True) - .alias("grades_rank") - .list.first() - ).to_series().to_list() == [ - 0.3333333333333333, - 0.6666666666666666, - 0.6666666666666666, - 0.3333333333333333, - ] - - -def test_list_eval_categorical() -> None: - df = pl.DataFrame({"test": [["a", None]]}, schema={"test": pl.List(pl.Categorical)}) - df = df.select( - pl.col("test").list.eval(pl.element().filter(pl.element().is_not_null())) - ) - assert_series_equal( - df.get_column("test"), pl.Series("test", [["a"]], dtype=pl.List(pl.Categorical)) - ) - - def test_list_ternary_concat() -> None: df = pl.DataFrame( { @@ -423,17 +451,6 @@ def test_arr_contains_categorical() -> None: assert result.to_dict(as_series=False) == expected -def test_list_eval_type_coercion() -> None: - last_non_null_value = pl.element().fill_null(3).last() - df = pl.DataFrame({"array_cols": [[1, None]]}) - - assert df.select( - pl.col("array_cols") - .list.eval(last_non_null_value, parallel=False) - .alias("col_last") - ).to_dict(as_series=False) == {"col_last": [[3]]} - - def test_list_slice() -> None: df = pl.DataFrame( { @@ -476,21 +493,6 @@ def test_list_sliced_get_5186() -> None: assert_frame_equal(out1, out2) -def test_empty_eval_dtype_5546() -> None: - # https://github.com/pola-rs/polars/issues/5546 - df = pl.DataFrame([{"a": [{"name": 1}, {"name": 2}]}]) - - dtype = df.dtypes[0] - - assert ( - df.limit(0).with_columns( - pl.col("a") - .list.eval(pl.element().filter(pl.first().struct.field("name") == 1)) - .alias("a_filtered") - ) - ).dtypes == [dtype, dtype] - - def test_list_amortized_apply_explode_5812() -> None: s = pl.Series([None, [1, 3], [0, -3], [1, 2, 2]]) assert s.list.sum().to_list() == [None, 4, -3, 5] @@ -548,16 +550,6 @@ def test_list_gather() -> None: ] -def test_list_eval_all_null() -> None: - df = pl.DataFrame({"foo": [1, 2, 3], "bar": [None, None, None]}).with_columns( - pl.col("bar").cast(pl.List(pl.String)) - ) - - assert df.select(pl.col("bar").list.eval(pl.element())).to_dict( - as_series=False - ) == {"bar": [None, None, None]} - - def test_list_function_group_awareness() -> None: df = pl.DataFrame( { @@ -825,13 +817,6 @@ def test_list_get_logical_type() -> None: assert_series_equal(out, expected) -def test_list_eval_gater_every_13410() -> None: - df = pl.DataFrame({"a": [[1, 2, 3], [4, 5, 6]]}) - out = df.with_columns(result=pl.col("a").list.eval(pl.element().gather_every(2))) - expected = pl.DataFrame({"a": [[1, 2, 3], [4, 5, 6]], "result": [[1, 3], [4, 6]]}) - assert_frame_equal(out, expected) - - def test_list_gather_every() -> None: df = pl.DataFrame( { @@ -896,24 +881,6 @@ def 
test_list_get_with_null() -> None: assert_frame_equal(out, expected) -def test_list_eval_err_raise_15653() -> None: - df = pl.DataFrame({"foo": [[]]}) - with pytest.raises(StructFieldNotFoundError): - df.with_columns(bar=pl.col("foo").list.eval(pl.element().struct.field("baz"))) - - def test_list_sum_bool_schema() -> None: q = pl.LazyFrame({"x": [[True, True, False]]}) assert q.select(pl.col("x").list.sum()).collect_schema()["x"] == pl.UInt32 - - -def test_list_eval_type_cast_11188() -> None: - df = pl.DataFrame( - [ - {"a": None}, - ], - schema={"a": pl.List(pl.Int64)}, - ) - assert df.select( - pl.col("a").list.eval(pl.element().cast(pl.String)).alias("a_str") - ).schema == {"a_str": pl.List(pl.String)} diff --git a/py-polars/tests/unit/operations/namespaces/string/test_string.py b/py-polars/tests/unit/operations/namespaces/string/test_string.py index 5c009a9776f0..fe47b8d07d2e 100644 --- a/py-polars/tests/unit/operations/namespaces/string/test_string.py +++ b/py-polars/tests/unit/operations/namespaces/string/test_string.py @@ -4,7 +4,12 @@ import polars as pl import polars.selectors as cs -from polars.exceptions import ComputeError, InvalidOperationError +from polars.exceptions import ( + ColumnNotFoundError, + ComputeError, + InvalidOperationError, + SchemaError, +) from polars.testing import assert_frame_equal, assert_series_equal @@ -1061,6 +1066,88 @@ def test_replace_many( ) +@pytest.mark.parametrize( + ("mapping", "case_insensitive", "expected"), + [ + ({}, False, "Tell me what you want"), + ({"me": "them"}, False, "Tell them what you want"), + ({"who": "them"}, False, "Tell me what you want"), + ({"me": "it", "you": "it"}, False, "Tell it what it want"), + ({"Me": "it", "you": "it"}, False, "Tell me what it want"), + ({"me": "you", "you": "me"}, False, "Tell you what me want"), + ({}, True, "Tell me what you want"), + ({"Me": "it", "you": "it"}, True, "Tell it what it want"), + ({"me": "you", "YOU": "me"}, True, "Tell you what me want"), + ], +) +def test_replace_many_mapping( + mapping: dict[str, str], + case_insensitive: bool, + expected: str, +) -> None: + df = pl.DataFrame({"text": ["Tell me what you want"]}) + # series + assert ( + expected + == df["text"] + .str.replace_many(mapping, ascii_case_insensitive=case_insensitive) + .item() + ) + # expr + assert ( + expected + == df.select( + pl.col("text").str.replace_many( + mapping, + ascii_case_insensitive=case_insensitive, + ) + ).item() + ) + + +def test_replace_many_invalid_inputs() -> None: + df = pl.DataFrame({"text": ["Tell me what you want"]}) + + # Ensure a string as the first argument is parsed as a column name. 
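A small illustrative sketch of the distinction the comment above makes, using only the list and mapping forms exercised elsewhere in this file: literal patterns are passed as a list or a mapping, whereas a bare string is read as a column reference.

    import polars as pl

    df = pl.DataFrame({"text": ["Tell me what you want"]})

    # A bare string such as replace_many("me", "you") is interpreted as a column
    # name and raises ColumnNotFoundError (asserted below); literal patterns are
    # supplied as lists or as a mapping instead.
    df.select(pl.col("text").str.replace_many(["me"], ["them"]))   # list form
    df.select(pl.col("text").str.replace_many({"me": "them"}))     # mapping form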
+ with pytest.raises(ColumnNotFoundError, match="me"): + df.select(pl.col("text").str.replace_many("me", "you")) + + with pytest.raises(SchemaError): + df.select(pl.col("text").str.replace_many(1, 2)) + + with pytest.raises(SchemaError): + df.select(pl.col("text").str.replace_many([1], [2])) + + with pytest.raises(SchemaError): + df.select(pl.col("text").str.replace_many(["me"], None)) + + with pytest.raises(TypeError): + df.select(pl.col("text").str.replace_many(["me"])) + + with pytest.raises( + InvalidOperationError, + match="expected the same amount of patterns as replacement strings", + ): + df.select(pl.col("text").str.replace_many(["a"], ["b", "c"])) + + s = df.to_series() + + with pytest.raises(ColumnNotFoundError, match="me"): + s.str.replace_many("me", "you") # type: ignore[arg-type] + + with pytest.raises(SchemaError): + df.select(pl.col("text").str.replace_many(["me"], None)) + + with pytest.raises(TypeError): + df.select(pl.col("text").str.replace_many(["me"])) + + with pytest.raises( + InvalidOperationError, + match="expected the same amount of patterns as replacement strings", + ): + s.str.replace_many(["a"], ["b", "c"]) + + def test_extract_all_count() -> None: df = pl.DataFrame({"foo": ["123 bla 45 asd", "xaz 678 910t", "boo", None]}) assert ( @@ -1489,23 +1576,52 @@ def test_splitn_expr() -> None: def test_titlecase() -> None: df = pl.DataFrame( { - "sing": [ + "misc": [ "welcome to my world", - "THERE'S NO TURNING BACK", "double space", "and\ta\t tab", - ] + "by jean-paul sartre, 'esq'", + "SOMETIMES/life/gives/you/a/2nd/chance", + ], } ) + expected = [ + "Welcome To My World", + "Double Space", + "And\tA\t Tab", + "By Jean-Paul Sartre, 'Esq'", + "Sometimes/Life/Gives/You/A/2nd/Chance", + ] + actual = df.select(pl.col("misc").str.to_titlecase()).to_series() + for ex, act in zip(expected, actual): + assert ex == act, f"{ex} != {act}" - assert df.select(pl.col("sing").str.to_titlecase()).to_dict(as_series=False) == { - "sing": [ - "Welcome To My World", - "There's No Turning Back", - "Double Space", - "And\tA\t Tab", - ] - } + df = pl.DataFrame( + { + "quotes": [ + "'e.t. phone home'", + "you talkin' to me?", + "i feel the need--the need for speed", + "to infinity,and BEYOND!", + "say 'what' again!i dare you - I\u00a0double-dare you!", + "What.we.got.here... is#failure#to#communicate", + ] + } + ) + expected_str = [ + "'E.T. Phone Home'", + "You Talkin' To Me?", + "I Feel The Need--The Need For Speed", + "To Infinity,And Beyond!", + "Say 'What' Again!I Dare You - I\u00a0Double-Dare You!", + "What.We.Got.Here... 
Is#Failure#To#Communicate", + ] + expected_py = [s.title() for s in df["quotes"].to_list()] + for ex_str, ex_py, act in zip( + expected_str, expected_py, df["quotes"].str.to_titlecase() + ): + assert ex_str == act, f"{ex_str} != {act}" + assert ex_py == act, f"{ex_py} != {act}" def test_string_replace_with_nulls_10124() -> None: diff --git a/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py b/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py index c9a43984cd45..fb4ddee68146 100644 --- a/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections import OrderedDict from datetime import date, datetime, time, timedelta from typing import TYPE_CHECKING @@ -15,17 +16,17 @@ if TYPE_CHECKING: from zoneinfo import ZoneInfo - from polars._typing import TemporalLiteral, TimeUnit + from polars._typing import PolarsDataType, TemporalLiteral, TimeUnit else: from polars._utils.convert import string_to_zoneinfo as ZoneInfo -@pytest.fixture() +@pytest.fixture def series_of_int_dates() -> pl.Series: return pl.Series([10000, 20000, 30000], dtype=pl.Date) -@pytest.fixture() +@pytest.fixture def series_of_str_dates() -> pl.Series: return pl.Series(["2020-01-01 00:00:00.000000000", "2020-02-02 03:20:10.987654321"]) @@ -1350,3 +1351,118 @@ def test_dt_mean_deprecated() -> None: with pytest.deprecated_call(): result = s.dt.mean() assert result == s.mean() + + +@pytest.mark.parametrize( + "dtype", + [ + pl.Date, + pl.Datetime("ms"), + pl.Datetime("ms", "EST"), + pl.Datetime("us"), + pl.Datetime("us", "EST"), + pl.Datetime("ns"), + pl.Datetime("ns", "EST"), + ], +) +@pytest.mark.parametrize( + "value", + [ + date(1677, 9, 22), + date(1970, 1, 1), + date(2024, 2, 29), + date(2262, 4, 11), + ], +) +def test_literal_from_date( + value: date, + dtype: PolarsDataType, +) -> None: + out = pl.select(pl.lit(value, dtype=dtype)) + assert out.schema == OrderedDict({"literal": dtype}) + if dtype == pl.Datetime: + tz = ZoneInfo(dtype.time_zone) if dtype.time_zone is not None else None # type: ignore[union-attr] + value = datetime(value.year, value.month, value.day, tzinfo=tz) + assert out.item() == value + + +@pytest.mark.parametrize( + "dtype", + [ + pl.Date, + pl.Datetime("ms"), + pl.Datetime("ms", "EST"), + pl.Datetime("us"), + pl.Datetime("us", "EST"), + pl.Datetime("ns"), + pl.Datetime("ns", "EST"), + ], +) +@pytest.mark.parametrize( + "value", + [ + datetime(1677, 9, 22), + datetime(1677, 9, 22, tzinfo=ZoneInfo("EST")), + datetime(1970, 1, 1), + datetime(1970, 1, 1, tzinfo=ZoneInfo("EST")), + datetime(2024, 2, 29), + datetime(2024, 2, 29, tzinfo=ZoneInfo("EST")), + datetime(2262, 4, 11), + datetime(2262, 4, 11, tzinfo=ZoneInfo("EST")), + ], +) +def test_literal_from_datetime( + value: datetime, + dtype: pl.Date | pl.Datetime, +) -> None: + out = pl.select(pl.lit(value, dtype=dtype)) + if dtype == pl.Date: + value = value.date() # type: ignore[assignment] + elif dtype.time_zone is None and value.tzinfo is not None: # type: ignore[union-attr] + # update the dtype with the supplied time zone in the value + dtype = pl.Datetime(dtype.time_unit, str(value.tzinfo)) # type: ignore[union-attr] + elif dtype.time_zone is not None and value.tzinfo is None: # type: ignore[union-attr] + # cast from dt without tz to dtype with tz + value = value.replace(tzinfo=ZoneInfo(dtype.time_zone)) # type: ignore[union-attr] + + assert out.schema == 
OrderedDict({"literal": dtype}) + assert out.item() == value + + +@pytest.mark.parametrize( + "value", + [ + time(0), + time(hour=1), + time(hour=16, minute=43, microsecond=500), + time(hour=23, minute=59, second=59, microsecond=999999), + ], +) +def test_literal_from_time(value: time) -> None: + out = pl.select(pl.lit(value)) + assert out.schema == OrderedDict({"literal": pl.Time}) + assert out.item() == value + + +@pytest.mark.parametrize( + "dtype", + [ + None, + pl.Duration("ms"), + pl.Duration("us"), + pl.Duration("ns"), + ], +) +@pytest.mark.parametrize( + "value", + [ + timedelta(0), + timedelta(hours=1), + timedelta(days=-99999), + timedelta(days=99999), + ], +) +def test_literal_from_timedelta(value: time, dtype: pl.Duration | None) -> None: + out = pl.select(pl.lit(value, dtype=dtype)) + assert out.schema == OrderedDict({"literal": dtype or pl.Duration("us")}) + assert out.item() == value diff --git a/py-polars/tests/unit/operations/namespaces/temporal/test_round.py b/py-polars/tests/unit/operations/namespaces/temporal/test_round.py index 1ac7acc3edcd..49ed4328b8f0 100644 --- a/py-polars/tests/unit/operations/namespaces/temporal/test_round.py +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_round.py @@ -189,3 +189,51 @@ def test_round_datetime_w_expression(time_unit: TimeUnit) -> None: result = df.select(pl.col("a").dt.round(pl.col("b")))["a"] assert result[0] == datetime(2020, 1, 1) assert result[1] == datetime(2020, 1, 21) + + +@pytest.mark.parametrize( + ("time_unit", "expected"), + [ + ("ms", 0), + ("us", 0), + ("ns", 0), + ], +) +def test_round_negative_towards_epoch_18239(time_unit: TimeUnit, expected: int) -> None: + s = pl.Series([datetime(1970, 1, 1)], dtype=pl.Datetime(time_unit)) + s = s.dt.offset_by(f"-1{time_unit}") + result = s.dt.round(f"2{time_unit}").dt.timestamp(time_unit="ns").item() + assert result == expected + result = ( + s.dt.replace_time_zone("Europe/London") + .dt.round(f"2{time_unit}") + .dt.replace_time_zone(None) + .dt.timestamp(time_unit="ns") + .item() + ) + assert result == expected + + +@pytest.mark.parametrize( + ("time_unit", "expected"), + [ + ("ms", 2_000_000), + ("us", 2_000), + ("ns", 2), + ], +) +def test_round_positive_away_from_epoch_18239( + time_unit: TimeUnit, expected: int +) -> None: + s = pl.Series([datetime(1970, 1, 1)], dtype=pl.Datetime(time_unit)) + s = s.dt.offset_by(f"1{time_unit}") + result = s.dt.round(f"2{time_unit}").dt.timestamp(time_unit="ns").item() + assert result == expected + result = ( + s.dt.replace_time_zone("Europe/London") + .dt.round(f"2{time_unit}") + .dt.replace_time_zone(None) + .dt.timestamp(time_unit="ns") + .item() + ) + assert result == expected diff --git a/py-polars/tests/unit/operations/namespaces/test_plot.py b/py-polars/tests/unit/operations/namespaces/test_plot.py index 34f8964512d8..fc2fbc02648a 100644 --- a/py-polars/tests/unit/operations/namespaces/test_plot.py +++ b/py-polars/tests/unit/operations/namespaces/test_plot.py @@ -1,15 +1,8 @@ -from datetime import date - -import pytest - import polars as pl -# Calling `plot` the first time is slow -# https://github.com/pola-rs/polars/issues/13500 -pytestmark = pytest.mark.slow - -def test_dataframe_scatter() -> None: +def test_dataframe_plot() -> None: + # dry-run, check nothing errors df = pl.DataFrame( { "length": [1, 4, 6], @@ -17,24 +10,25 @@ def test_dataframe_scatter() -> None: "species": ["setosa", "setosa", "versicolor"], } ) - df.plot.scatter(x="length", y="width", by="species") + df.plot.line(x="length", y="width", 
color="species").to_json() + df.plot.point(x="length", y="width", size="species").to_json() + df.plot.scatter(x="length", y="width", size="species").to_json() + df.plot.bar(x="length", y="width", color="species").to_json() + df.plot.area(x="length", y="width", color="species").to_json() -def test_dataframe_line() -> None: - df = pl.DataFrame( - { - "date": [date(2020, 1, 2), date(2020, 1, 3), date(2020, 1, 3)], - "stock_1": [1, 4, 6], - "stock_2": [1, 5, 2], - } - ) - df.plot.line(x="date", y=["stock_1", "stock_2"]) +def test_series_plot() -> None: + # dry-run, check nothing errors + s = pl.Series("a", [1, 4, 4, 4, 7, 2, 5, 3, 6]) + s.plot.kde().to_json() + s.plot.hist().to_json() + s.plot.line().to_json() + s.plot.point().to_json() -def test_series_hist() -> None: - s = pl.Series("values", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) - s.plot.hist() +def test_empty_dataframe() -> None: + pl.DataFrame({"a": [], "b": []}).plot.point(x="a", y="b") -def test_empty_dataframe() -> None: - pl.DataFrame({"a": [], "b": []}).plot.scatter(x="a", y="b") +def test_nameless_series() -> None: + pl.Series([1, 2, 3]).plot.kde().to_json() diff --git a/py-polars/tests/unit/operations/rolling/test_rolling.py b/py-polars/tests/unit/operations/rolling/test_rolling.py index 8e5bbfd69bd1..4ab0c18d4873 100644 --- a/py-polars/tests/unit/operations/rolling/test_rolling.py +++ b/py-polars/tests/unit/operations/rolling/test_rolling.py @@ -22,7 +22,7 @@ from polars._typing import ClosedInterval, PolarsDataType, TimeUnit -@pytest.fixture() +@pytest.fixture def example_df() -> pl.DataFrame: return pl.DataFrame( { @@ -589,6 +589,44 @@ def test_rolling_cov_corr() -> None: assert res["corr"][:2] == [None] * 2 +def test_rolling_cov_corr_nulls() -> None: + df1 = pl.DataFrame( + {"a": [1.06, 1.07, 0.93, 0.78, 0.85], "lag_a": [1.0, 1.06, 1.07, 0.93, 0.78]} + ) + df2 = pl.DataFrame( + { + "a": [1.0, 1.06, 1.07, 0.93, 0.78, 0.85], + "lag_a": [None, 1.0, 1.06, 1.07, 0.93, 0.78], + } + ) + + val_1 = df1.select( + pl.rolling_corr("a", "lag_a", window_size=10, min_periods=5, ddof=1) + ) + val_2 = df2.select( + pl.rolling_corr("a", "lag_a", window_size=10, min_periods=5, ddof=1) + ) + + df1_expected = pl.DataFrame({"a": [None, None, None, None, 0.62204709]}) + df2_expected = pl.DataFrame({"a": [None, None, None, None, None, 0.62204709]}) + + assert_frame_equal(val_1, df1_expected, atol=0.0000001) + assert_frame_equal(val_2, df2_expected, atol=0.0000001) + + val_1 = df1.select( + pl.rolling_cov("a", "lag_a", window_size=10, min_periods=5, ddof=1) + ) + val_2 = df2.select( + pl.rolling_cov("a", "lag_a", window_size=10, min_periods=5, ddof=1) + ) + + df1_expected = pl.DataFrame({"a": [None, None, None, None, 0.009445]}) + df2_expected = pl.DataFrame({"a": [None, None, None, None, None, 0.009445]}) + + assert_frame_equal(val_1, df1_expected, atol=0.0000001) + assert_frame_equal(val_2, df2_expected, atol=0.0000001) + + @pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) def test_rolling_empty_window_9406(time_unit: TimeUnit) -> None: datecol = pl.Series( @@ -873,7 +911,7 @@ def test_rolling_median() -> None: ) -@pytest.mark.slow() +@pytest.mark.slow def test_rolling_median_2() -> None: np.random.seed(12) n = 1000 diff --git a/py-polars/tests/unit/operations/test_bitwise.py b/py-polars/tests/unit/operations/test_bitwise.py index b6f2fd12edbb..e69fe2b218bd 100644 --- a/py-polars/tests/unit/operations/test_bitwise.py +++ b/py-polars/tests/unit/operations/test_bitwise.py @@ -8,3 +8,12 @@ def test_bitwise_integral_schema(op: str) -> None: df = 
pl.LazyFrame({"a": [1, 2], "b": [3, 4]}) q = df.select(getattr(pl.col("a"), op)(pl.col("b"))) assert q.collect_schema()["a"] == df.collect_schema()["a"] + + +@pytest.mark.parametrize("op", ["and_", "or_", "xor"]) +def test_bitwise_single_null_value_schema(op: str) -> None: + df = pl.DataFrame({"a": [True, True]}) + q = df.select(getattr(pl.col("a"), op)(None)) + result_schema = q.collect_schema() + assert result_schema.len() == 1 + assert "a" in result_schema diff --git a/py-polars/tests/unit/operations/test_clip.py b/py-polars/tests/unit/operations/test_clip.py index af6b1965d039..906f217a4ab6 100644 --- a/py-polars/tests/unit/operations/test_clip.py +++ b/py-polars/tests/unit/operations/test_clip.py @@ -9,7 +9,7 @@ from polars.testing import assert_frame_equal -@pytest.fixture() +@pytest.fixture def clip_exprs() -> list[pl.Expr]: return [ pl.col("a").clip(pl.col("min"), pl.col("max")).alias("clip"), diff --git a/py-polars/tests/unit/operations/test_comparison.py b/py-polars/tests/unit/operations/test_comparison.py index eb74f374f6a1..42535cff85a6 100644 --- a/py-polars/tests/unit/operations/test_comparison.py +++ b/py-polars/tests/unit/operations/test_comparison.py @@ -307,7 +307,7 @@ def verify_total_ordering_broadcast( ] -@pytest.mark.slow() +@pytest.mark.slow @pytest.mark.parametrize("lhs", INTERESTING_FLOAT_VALUES) @pytest.mark.parametrize("rhs", INTERESTING_FLOAT_VALUES) def test_total_ordering_float_series(lhs: float | None, rhs: float | None) -> None: @@ -335,7 +335,7 @@ def test_total_ordering_float_series(lhs: float | None, rhs: float | None) -> No ] -@pytest.mark.slow() +@pytest.mark.slow @pytest.mark.parametrize("lhs", INTERESTING_STRING_VALUES) @pytest.mark.parametrize("rhs", INTERESTING_STRING_VALUES) def test_total_ordering_string_series(lhs: str | None, rhs: str | None) -> None: @@ -347,7 +347,7 @@ def test_total_ordering_string_series(lhs: str | None, rhs: str | None) -> None: verify_total_ordering_broadcast(lhs, rhs, "", pl.String) -@pytest.mark.slow() +@pytest.mark.slow @pytest.mark.parametrize("str_lhs", INTERESTING_STRING_VALUES) @pytest.mark.parametrize("str_rhs", INTERESTING_STRING_VALUES) def test_total_ordering_binary_series(str_lhs: str | None, str_rhs: str | None) -> None: diff --git a/py-polars/tests/unit/operations/test_explode.py b/py-polars/tests/unit/operations/test_explode.py index 24e65ac1dc6d..14aefa93c3c1 100644 --- a/py-polars/tests/unit/operations/test_explode.py +++ b/py-polars/tests/unit/operations/test_explode.py @@ -167,7 +167,7 @@ def test_list_struct_explode_6905() -> None: }, schema={"group": pl.List(pl.Struct([pl.Field("params", pl.List(pl.Int32))]))}, )["group"].list.explode().to_list() == [ - {"params": None}, + None, {"params": [1]}, {"params": []}, ] @@ -447,3 +447,8 @@ def test_explode_17648() -> None: .with_columns(pl.int_ranges(pl.col("a").list.len()).alias("count")) .explode("a", "count") ).to_dict(as_series=False) == {"a": [2, 6, 7, 3, 9, 2], "count": [0, 1, 2, 0, 1, 2]} + + +def test_explode_struct_nulls() -> None: + df = pl.DataFrame({"A": [[{"B": 1}], [None], []]}) + assert df.explode("A").to_dict(as_series=False) == {"A": [{"B": 1}, None, None]} diff --git a/py-polars/tests/unit/operations/test_filter.py b/py-polars/tests/unit/operations/test_filter.py index df796b44b991..a09c625cbe5b 100644 --- a/py-polars/tests/unit/operations/test_filter.py +++ b/py-polars/tests/unit/operations/test_filter.py @@ -257,7 +257,7 @@ def test_filter_horizontal_selector_15428() -> None: assert_frame_equal(df, expected_df) -@pytest.mark.slow() 
+@pytest.mark.slow @pytest.mark.parametrize( "dtype", [pl.Boolean, pl.Int8, pl.Int16, pl.Int32, pl.Int64, pl.String] ) @@ -285,3 +285,16 @@ def test_filter_group_aware_17030() -> None: (group_count > 2) & (group_cum_count > 1) & (group_cum_count < group_count) ) assert df.filter(filter_expr)["foo"].to_list() == ["1", "2"] + + +def test_invalid_filter_18295() -> None: + codes = ["a"] * 5 + ["b"] * 5 + values = list(range(-2, 3)) + list(range(2, -3, -1)) + df = pl.DataFrame({"code": codes, "value": values}) + with pytest.raises(pl.exceptions.ShapeError): + df.group_by("code").agg( + pl.col("value") + .ewm_mean(span=2, ignore_nulls=True) + .tail(3) + .filter(pl.col("value") > 0), + ).sort("code") diff --git a/py-polars/tests/unit/operations/test_gather.py b/py-polars/tests/unit/operations/test_gather.py index fab07dc71956..d56e5d80858d 100644 --- a/py-polars/tests/unit/operations/test_gather.py +++ b/py-polars/tests/unit/operations/test_gather.py @@ -148,3 +148,11 @@ def test_chunked_gather_phys_repr_17446() -> None: dfb = pl.concat([dfb, dfb]) assert dfa.join(dfb, how="left", on=pl.col("replace_unique_id")).shape == (4, 2) + + +def test_gather_str_col_18099() -> None: + df = pl.DataFrame({"foo": [1, 2, 3], "idx": [0, 0, 1]}) + assert df.with_columns(pl.col("foo").gather("idx")).to_dict(as_series=False) == { + "foo": [1, 1, 2], + "idx": [0, 0, 1], + } diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index a448a25c2370..09864b329c22 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -249,7 +249,7 @@ def test_group_by_median_by_dtype( assert_frame_equal(result, df_expected) -@pytest.fixture() +@pytest.fixture def df() -> pl.DataFrame: return pl.DataFrame( { @@ -642,7 +642,7 @@ def test_group_by_binary_agg_with_literal() -> None: assert out.to_dict(as_series=False) == {"id": ["a", "b"], "value": [[4, 6], [4, 6]]} -@pytest.mark.slow() +@pytest.mark.slow @pytest.mark.parametrize("dtype", [pl.Int32, pl.UInt32]) def test_overflow_mean_partitioned_group_by_5194(dtype: PolarsDataType) -> None: df = pl.DataFrame( @@ -861,7 +861,7 @@ def test_group_by_apply_first_input_is_literal() -> None: pow = df.group_by("g").agg(2 ** pl.col("x")) assert pow.sort("g").to_dict(as_series=False) == { "g": [1, 2], - "x": [[2.0, 4.0], [8.0, 16.0, 32.0]], + "literal": [[2.0, 4.0], [8.0, 16.0, 32.0]], } @@ -924,7 +924,7 @@ def test_group_by_multiple_null_cols_15623() -> None: assert df.is_empty() -@pytest.mark.release() +@pytest.mark.release def test_categorical_vs_str_group_by() -> None: # this triggers the perfect hash table s = pl.Series("a", np.random.randint(0, 50, 100)) @@ -951,7 +951,7 @@ def test_categorical_vs_str_group_by() -> None: ) -@pytest.mark.release() +@pytest.mark.release def test_boolean_min_max_agg() -> None: np.random.seed(0) idx = np.random.randint(0, 500, 1000) diff --git a/py-polars/tests/unit/operations/test_inequality_join.py b/py-polars/tests/unit/operations/test_inequality_join.py new file mode 100644 index 000000000000..7b3ddb279fec --- /dev/null +++ b/py-polars/tests/unit/operations/test_inequality_join.py @@ -0,0 +1,457 @@ +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING, Any + +import hypothesis.strategies as st +import numpy as np +import pytest +from hypothesis import given + +import polars as pl +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from hypothesis.strategies import DrawFn, 
SearchStrategy + + +def test_self_join() -> None: + west = pl.DataFrame( + { + "t_id": [404, 498, 676, 742], + "time": [100, 140, 80, 90], + "cost": [6, 11, 10, 5], + "cores": [4, 2, 1, 4], + } + ) + + actual = west.join_where( + west, pl.col("time") > pl.col("time"), pl.col("cost") < pl.col("cost") + ) + + expected = pl.DataFrame( + { + "t_id": [742, 404], + "time": [90, 100], + "cost": [5, 6], + "cores": [4, 4], + "t_id_right": [676, 676], + "time_right": [80, 80], + "cost_right": [10, 10], + "cores_right": [1, 1], + } + ) + assert_frame_equal(actual, expected, check_row_order=False, check_exact=True) + + +def test_basic_ie_join() -> None: + east = pl.DataFrame( + { + "id": [100, 101, 102], + "dur": [140, 100, 90], + "rev": [12, 12, 5], + "cores": [2, 8, 4], + } + ) + west = pl.DataFrame( + { + "t_id": [404, 498, 676, 742], + "time": [100, 140, 80, 90], + "cost": [6, 11, 10, 5], + "cores": [4, 2, 1, 4], + } + ) + + actual = east.join_where( + west, pl.col("dur") < pl.col("time"), pl.col("rev") > pl.col("cost") + ) + + expected = pl.DataFrame( + { + "id": [101], + "dur": [100], + "rev": [12], + "cores": [8], + "t_id": [498], + "time": [140], + "cost": [11], + "cores_right": [2], + } + ) + assert_frame_equal(actual, expected, check_row_order=False, check_exact=True) + + +@given( + offset=st.integers(-6, 5), + length=st.integers(0, 6), +) +def test_ie_join_with_slice(offset: int, length: int) -> None: + east = pl.DataFrame( + { + "id": [100, 101, 102], + "dur": [120, 140, 160], + "rev": [12, 14, 16], + "cores": [2, 8, 4], + } + ).lazy() + west = pl.DataFrame( + { + "t_id": [404, 498, 676, 742], + "time": [90, 130, 150, 170], + "cost": [9, 13, 15, 16], + "cores": [4, 2, 1, 4], + } + ).lazy() + + actual = ( + east.join_where( + west, pl.col("dur") < pl.col("time"), pl.col("rev") < pl.col("cost") + ) + .slice(offset, length) + .collect() + ) + + expected_full = pl.DataFrame( + { + "id": [101, 101, 100, 100, 100], + "dur": [140, 140, 120, 120, 120], + "rev": [14, 14, 12, 12, 12], + "cores": [8, 8, 2, 2, 2], + "t_id": [676, 742, 498, 676, 742], + "time": [150, 170, 130, 150, 170], + "cost": [15, 16, 13, 15, 16], + "cores_right": [1, 4, 2, 1, 4], + } + ) + # The ordering of the result is arbitrary, so we can + # only verify that each row of the slice is present in the full expected result. 
+ assert len(actual) == len(expected_full.slice(offset, length)) + + expected_rows = set(expected_full.iter_rows()) + for row in actual.iter_rows(): + assert row in expected_rows, f"{row} not in expected rows" + + +def test_ie_join_with_expressions() -> None: + east = pl.DataFrame( + { + "id": [100, 101, 102], + "dur": [70, 50, 45], + "rev": [12, 12, 5], + "cores": [2, 8, 4], + } + ) + west = pl.DataFrame( + { + "t_id": [404, 498, 676, 742], + "time": [100, 140, 80, 90], + "cost": [12, 22, 20, 10], + "cores": [4, 2, 1, 4], + } + ) + + actual = east.join_where( + west, + (pl.col("dur") * 2) < pl.col("time"), + pl.col("rev") > (pl.col("cost").cast(pl.Int32) // 2).cast(pl.Int64), + ) + + expected = pl.DataFrame( + { + "id": [101], + "dur": [50], + "rev": [12], + "cores": [8], + "t_id": [498], + "time": [140], + "cost": [22], + "cores_right": [2], + } + ) + assert_frame_equal(actual, expected, check_row_order=False, check_exact=True) + + +def test_join_where_predicates() -> None: + left = pl.DataFrame( + { + "id": [0, 1, 2, 3, 4, 5], + "group": [0, 0, 0, 1, 1, 1], + "time": [ + datetime(2024, 8, 26, 15, 34, 30), + datetime(2024, 8, 26, 15, 35, 30), + datetime(2024, 8, 26, 15, 36, 30), + datetime(2024, 8, 26, 15, 37, 30), + datetime(2024, 8, 26, 15, 38, 0), + datetime(2024, 8, 26, 15, 39, 0), + ], + } + ) + right = pl.DataFrame( + { + "id": [0, 1, 2], + "group": [0, 1, 1], + "start_time": [ + datetime(2024, 8, 26, 15, 34, 0), + datetime(2024, 8, 26, 15, 35, 0), + datetime(2024, 8, 26, 15, 38, 0), + ], + "end_time": [ + datetime(2024, 8, 26, 15, 36, 0), + datetime(2024, 8, 26, 15, 37, 0), + datetime(2024, 8, 26, 15, 39, 0), + ], + } + ) + + actual = left.join_where( + right, + pl.col("time") >= pl.col("start_time"), + pl.col("time") < pl.col("end_time"), + ).select("id", "id_right") + + expected = pl.DataFrame( + { + "id": [0, 1, 1, 2, 4], + "id_right": [0, 0, 1, 1, 2], + } + ) + assert_frame_equal(actual, expected, check_row_order=False, check_exact=True) + + q = ( + left.lazy() + .join_where( + right.lazy(), + pl.col("time") >= pl.col("start_time"), + pl.col("time") < pl.col("end_time"), + pl.col("group") == pl.col("group"), + ) + .select("id", "id_right", "group") + .sort("id") + ) + + explained = q.explain() + assert "INNER JOIN" in explained + assert "FILTER" in explained + actual = q.collect() + + expected = ( + left.join(right, how="cross") + .filter( + pl.col("time") >= pl.col("start_time"), + pl.col("time") < pl.col("end_time"), + pl.col("group") == pl.col("group_right"), + ) + .select("id", "id_right", "group") + .sort("id") + ) + assert_frame_equal(actual, expected, check_exact=True) + + q = ( + left.lazy() + .join_where( + right.lazy(), + pl.col("time") >= pl.col("start_time"), + pl.col("time") < pl.col("end_time"), + pl.col("group") != pl.col("group"), + ) + .select("id", "id_right", "group") + .sort("id") + ) + + explained = q.explain() + assert "IEJOIN" in explained + assert "FILTER" in explained + actual = q.collect() + + expected = ( + left.join(right, how="cross") + .filter( + pl.col("time") >= pl.col("start_time"), + pl.col("time") < pl.col("end_time"), + pl.col("group") != pl.col("group_right"), + ) + .select("id", "id_right", "group") + .sort("id") + ) + assert_frame_equal(actual, expected, check_exact=True) + + q = ( + left.lazy() + .join_where( + right.lazy(), + pl.col("group") != pl.col("group"), + ) + .select("id", "group", "group_right") + .sort("id") + .select("group", "group_right") + ) + + explained = q.explain() + assert "CROSS" in explained + assert "FILTER" in 
explained + actual = q.collect() + assert actual.to_dict(as_series=False) == { + "group": [0, 0, 0, 0, 0, 0, 1, 1, 1], + "group_right": [1, 1, 1, 1, 1, 1, 0, 0, 0], + } + + +def _inequality_expression(col1: str, op: str, col2: str) -> pl.Expr: + if op == "<": + return pl.col(col1) < pl.col(col2) + elif op == "<=": + return pl.col(col1) <= pl.col(col2) + elif op == ">": + return pl.col(col1) > pl.col(col2) + elif op == ">=": + return pl.col(col1) >= pl.col(col2) + else: + message = f"Invalid operator '{op}'" + raise ValueError(message) + + +def operators() -> SearchStrategy[str]: + valid_operators = ["<", "<=", ">", ">="] + return st.sampled_from(valid_operators) + + +@st.composite +def east_df( + draw: DrawFn, with_nulls: bool = False, use_floats: bool = False +) -> pl.DataFrame: + height = draw(st.integers(min_value=0, max_value=20)) + + if use_floats: + dur_strategy: SearchStrategy[Any] = st.floats(allow_nan=True) + rev_strategy: SearchStrategy[Any] = st.floats(allow_nan=True) + dur_dtype: type[pl.DataType] = pl.Float32 + rev_dtype: type[pl.DataType] = pl.Float32 + else: + dur_strategy = st.integers(min_value=100, max_value=105) + rev_strategy = st.integers(min_value=9, max_value=13) + dur_dtype = pl.Int64 + rev_dtype = pl.Int64 + + if with_nulls: + dur_strategy = dur_strategy | st.none() + rev_strategy = rev_strategy | st.none() + + cores_strategy = st.integers(min_value=1, max_value=10) + + ids = np.arange(0, height) + dur = draw(st.lists(dur_strategy, min_size=height, max_size=height)) + rev = draw(st.lists(rev_strategy, min_size=height, max_size=height)) + cores = draw(st.lists(cores_strategy, min_size=height, max_size=height)) + + return pl.DataFrame( + [ + pl.Series("id", ids, dtype=pl.Int64), + pl.Series("dur", dur, dtype=dur_dtype), + pl.Series("rev", rev, dtype=rev_dtype), + pl.Series("cores", cores, dtype=pl.Int64), + ] + ) + + +@st.composite +def west_df( + draw: DrawFn, with_nulls: bool = False, use_floats: bool = False +) -> pl.DataFrame: + height = draw(st.integers(min_value=0, max_value=20)) + + if use_floats: + time_strategy: SearchStrategy[Any] = st.floats(allow_nan=True) + cost_strategy: SearchStrategy[Any] = st.floats(allow_nan=True) + time_dtype: type[pl.DataType] = pl.Float32 + cost_dtype: type[pl.DataType] = pl.Float32 + else: + time_strategy = st.integers(min_value=100, max_value=105) + cost_strategy = st.integers(min_value=9, max_value=13) + time_dtype = pl.Int64 + cost_dtype = pl.Int64 + + if with_nulls: + time_strategy = time_strategy | st.none() + cost_strategy = cost_strategy | st.none() + + cores_strategy = st.integers(min_value=1, max_value=10) + + t_id = np.arange(100, 100 + height) + time = draw(st.lists(time_strategy, min_size=height, max_size=height)) + cost = draw(st.lists(cost_strategy, min_size=height, max_size=height)) + cores = draw(st.lists(cores_strategy, min_size=height, max_size=height)) + + return pl.DataFrame( + [ + pl.Series("t_id", t_id, dtype=pl.Int64), + pl.Series("time", time, dtype=time_dtype), + pl.Series("cost", cost, dtype=cost_dtype), + pl.Series("cores", cores, dtype=pl.Int64), + ] + ) + + +@given( + east=east_df(), + west=west_df(), + op1=operators(), + op2=operators(), +) +def test_ie_join(east: pl.DataFrame, west: pl.DataFrame, op1: str, op2: str) -> None: + expr0 = _inequality_expression("dur", op1, "time") + expr1 = _inequality_expression("rev", op2, "cost") + + actual = east.join_where(west, expr0, expr1) + + expected = east.join(west, how="cross").filter(expr0 & expr1) + assert_frame_equal(actual, expected, 
check_row_order=False, check_exact=True) + + +@given( + east=east_df(with_nulls=True), + west=west_df(with_nulls=True), + op1=operators(), + op2=operators(), +) +def test_ie_join_with_nulls( + east: pl.DataFrame, west: pl.DataFrame, op1: str, op2: str +) -> None: + expr0 = _inequality_expression("dur", op1, "time") + expr1 = _inequality_expression("rev", op2, "cost") + + actual = east.join_where(west, expr0, expr1) + + expected = east.join(west, how="cross").filter(expr0 & expr1) + assert_frame_equal(actual, expected, check_row_order=False, check_exact=True) + + +@given( + east=east_df(use_floats=True), + west=west_df(use_floats=True), + op1=operators(), + op2=operators(), +) +def test_ie_join_with_floats( + east: pl.DataFrame, west: pl.DataFrame, op1: str, op2: str +) -> None: + expr0 = _inequality_expression("dur", op1, "time") + expr1 = _inequality_expression("rev", op2, "cost") + + actual = east.join_where(west, expr0, expr1) + + expected = east.join(west, how="cross").filter(expr0 & expr1) + assert_frame_equal(actual, expected, check_row_order=False, check_exact=True) + + +def test_raise_on_suffixed_predicate_18604() -> None: + df = pl.DataFrame({"id": [1, 2]}) + with pytest.raises(pl.exceptions.ColumnNotFoundError): + df.join_where(df, pl.col("id") >= pl.col("id_right")) + + +def test_raise_on_multiple_binary_comparisons() -> None: + df = pl.DataFrame({"id": [1, 2]}) + with pytest.raises(pl.exceptions.InvalidOperationError): + df.join_where( + df, (pl.col("id") < pl.col("id")) & (pl.col("id") >= pl.col("id")) + ) diff --git a/py-polars/tests/unit/operations/test_interpolate_by.py b/py-polars/tests/unit/operations/test_interpolate_by.py index 423992abeadd..98ee656fdaed 100644 --- a/py-polars/tests/unit/operations/test_interpolate_by.py +++ b/py-polars/tests/unit/operations/test_interpolate_by.py @@ -28,6 +28,8 @@ pl.Int32, pl.UInt64, pl.UInt32, + pl.Float32, + pl.Float64, ], ) @pytest.mark.parametrize( @@ -116,22 +118,42 @@ def test_interpolate_by_leading_nulls() -> None: assert_frame_equal(result, expected) -def test_interpolate_by_trailing_nulls() -> None: - df = pl.DataFrame( - { - "times": [ - date(2020, 1, 1), - date(2020, 1, 3), - date(2020, 1, 10), - date(2020, 1, 11), - date(2020, 1, 12), - date(2020, 1, 13), - ], - "values": [1, None, None, 5, None, None], - } - ) +@pytest.mark.parametrize("dataset", ["floats", "dates"]) +def test_interpolate_by_trailing_nulls(dataset: str) -> None: + input_data = { + "dates": pl.DataFrame( + { + "times": [ + date(2020, 1, 1), + date(2020, 1, 3), + date(2020, 1, 10), + date(2020, 1, 11), + date(2020, 1, 12), + date(2020, 1, 13), + ], + "values": [1, None, None, 5, None, None], + } + ), + "floats": pl.DataFrame( + { + "times": [0.2, 0.4, 0.5, 0.6, 0.9, 1.1], + "values": [1, None, None, 5, None, None], + } + ), + } + + expected_data = { + "dates": pl.DataFrame( + {"values": [1.0, 1.7999999999999998, 4.6, 5.0, None, None]} + ), + "floats": pl.DataFrame({"values": [1.0, 3.0, 4.0, 5.0, None, None]}), + } + + df = input_data[dataset] + expected = expected_data[dataset] + result = df.select(pl.col("values").interpolate_by("times")) - expected = pl.DataFrame({"values": [1.0, 1.7999999999999998, 4.6, 5.0, None, None]}) + assert_frame_equal(result, expected) result = ( df.sort("times", descending=True) @@ -142,16 +164,28 @@ def test_interpolate_by_trailing_nulls() -> None: assert_frame_equal(result, expected) -@given(data=st.data()) -def test_interpolate_vs_numpy(data: st.DataObject) -> None: +@given(data=st.data(), x_dtype=st.sampled_from([pl.Date, 
pl.Float64])) +def test_interpolate_vs_numpy(data: st.DataObject, x_dtype: pl.DataType) -> None: + if x_dtype == pl.Float64: + by_strategy = st.floats( + min_value=-1e150, + max_value=1e150, + allow_nan=False, + allow_infinity=False, + allow_subnormal=False, + ) + else: + by_strategy = None + dataframe = ( data.draw( dataframes( [ column( "ts", - dtype=pl.Date, + dtype=x_dtype, allow_null=False, + strategy=by_strategy, ), column( "value", @@ -166,13 +200,24 @@ def test_interpolate_vs_numpy(data: st.DataObject) -> None: .fill_nan(None) .unique("ts") ) + + if x_dtype == pl.Float64: + assume(not dataframe["ts"].is_nan().any()) + assume(not dataframe["ts"].is_null().any()) + assume(not dataframe["ts"].is_in([float("-inf"), float("inf")]).any()) + assume(not dataframe["value"].is_null().all()) assume(not dataframe["value"].is_in([float("-inf"), float("inf")]).any()) + + dataframe = dataframe.sort("ts") + result = dataframe.select(pl.col("value").interpolate_by("ts"))["value"] mask = dataframe["value"].is_not_null() - x = dataframe["ts"].to_numpy().astype("int64") - xp = dataframe["ts"].filter(mask).to_numpy().astype("int64") + + np_dtype = "int64" if x_dtype == pl.Date else "float64" + x = dataframe["ts"].to_numpy().astype(np_dtype) + xp = dataframe["ts"].filter(mask).to_numpy().astype(np_dtype) yp = dataframe["value"].filter(mask).to_numpy().astype("float64") interp = np.interp(x, xp, yp) # Polars preserves nulls on boundaries, but NumPy doesn't. diff --git a/py-polars/tests/unit/operations/test_is_in.py b/py-polars/tests/unit/operations/test_is_in.py index 34641470b402..c4b88745dad6 100644 --- a/py-polars/tests/unit/operations/test_is_in.py +++ b/py-polars/tests/unit/operations/test_is_in.py @@ -136,7 +136,7 @@ def test_is_in_series() -> None: with pytest.raises( InvalidOperationError, - match=r"`is_in` cannot check for String values in Int64 data", + match=r"'is_in' cannot check for String values in Int64 data", ): df.select(pl.col("b").is_in(["x", "x"])) @@ -192,12 +192,12 @@ def test_is_in_float(dtype: PolarsDataType) -> None: ( pl.DataFrame({"a": ["1", "2"], "b": [[1, 2], [3, 4]]}), None, - r"`is_in` cannot check for String values in List\(Int64\) data", + r"'is_in' cannot check for String values in List\(Int64\) data", ), ( pl.DataFrame({"a": [date.today(), None], "b": [[1, 2], [3, 4]]}), None, - r"`is_in` cannot check for Date values in List\(Int64\) data", + r"'is_in' cannot check for Date values in List\(Int64\) data", ), ], ) diff --git a/py-polars/tests/unit/operations/test_join.py b/py-polars/tests/unit/operations/test_join.py index 069c9fc305c6..0a9d7ab2d9fd 100644 --- a/py-polars/tests/unit/operations/test_join.py +++ b/py-polars/tests/unit/operations/test_join.py @@ -618,7 +618,7 @@ def test_full_outer_join_list_() -> None: } -@pytest.mark.slow() +@pytest.mark.slow def test_join_validation() -> None: def test_each_join_validation( unique: pl.DataFrame, duplicate: pl.DataFrame, on: str, how: JoinStrategy @@ -788,8 +788,7 @@ def test_join_on_wildcard_error() -> None: df = pl.DataFrame({"x": [1]}) df2 = pl.DataFrame({"x": [1], "y": [2]}) with pytest.raises( - ComputeError, - match="wildcard column selection not supported at this point", + InvalidOperationError, ): df.join(df2, on=pl.all()) @@ -798,8 +797,7 @@ def test_join_on_nth_error() -> None: df = pl.DataFrame({"x": [1]}) df2 = pl.DataFrame({"x": [1], "y": [2]}) with pytest.raises( - ComputeError, - match=r"nth column selection not supported at this point \(n=0\)", + InvalidOperationError, ): df.join(df2, on=pl.first()) @@ 
-850,7 +848,7 @@ def test_join_list_non_numeric() -> None: } -@pytest.mark.slow() +@pytest.mark.slow def test_join_4_columns_with_validity() -> None: # join on 4 columns so we trigger combine validities # use 138 as that is 2 u64 and a remainder @@ -872,7 +870,7 @@ def test_join_4_columns_with_validity() -> None: ) -@pytest.mark.release() +@pytest.mark.release def test_cross_join() -> None: # triggers > 100 rows implementation # https://github.com/pola-rs/polars/blob/5f5acb2a523ce01bc710768b396762b8e69a9e07/polars/polars-core/src/frame/cross_join.rs#L34 @@ -883,7 +881,7 @@ def test_cross_join() -> None: assert_frame_equal(df2.join(df1, how="cross").slice(0, 100), out) -@pytest.mark.release() +@pytest.mark.release def test_cross_join_slice_pushdown() -> None: # this will likely go out of memory if we did not pushdown the slice df = ( diff --git a/py-polars/tests/unit/operations/test_random.py b/py-polars/tests/unit/operations/test_random.py index 98f781a45432..ce8e5644bc77 100644 --- a/py-polars/tests/unit/operations/test_random.py +++ b/py-polars/tests/unit/operations/test_random.py @@ -116,27 +116,6 @@ def test_sample_series() -> None: assert len(s.sample(n=10, with_replacement=True, seed=0)) == 10 -def test_rank_random_expr() -> None: - df = pl.from_dict( - {"a": [1] * 5, "b": [1, 2, 3, 4, 5], "c": [200, 100, 100, 50, 100]} - ) - - df_ranks1 = df.with_columns( - pl.col("c").rank(method="random", seed=1).over("a").alias("rank") - ) - df_ranks2 = df.with_columns( - pl.col("c").rank(method="random", seed=1).over("a").alias("rank") - ) - assert_frame_equal(df_ranks1, df_ranks2) - - -def test_rank_random_series() -> None: - s = pl.Series("a", [1, 2, 3, 2, 2, 3, 0]) - assert_series_equal( - s.rank("random", seed=1), pl.Series("a", [2, 4, 7, 3, 5, 6, 1], dtype=pl.UInt32) - ) - - def test_shuffle_expr() -> None: # pl.set_random_seed should lead to reproducible results. 
s = pl.Series("a", range(20))
diff --git a/py-polars/tests/unit/operations/test_rank.py b/py-polars/tests/unit/operations/test_rank.py
new file mode 100644
index 000000000000..6f83663875b6
--- /dev/null
+++ b/py-polars/tests/unit/operations/test_rank.py
@@ -0,0 +1,97 @@
+import polars as pl
+from polars.testing import assert_frame_equal, assert_series_equal
+
+
+def test_rank_nulls() -> None:
+    assert pl.Series([]).rank().to_list() == []
+    assert pl.Series([None]).rank().to_list() == [None]
+    assert pl.Series([None, None]).rank().to_list() == [None, None]
+
+
+def test_rank_random_expr() -> None:
+    df = pl.from_dict(
+        {"a": [1] * 5, "b": [1, 2, 3, 4, 5], "c": [200, 100, 100, 50, 100]}
+    )
+
+    df_ranks1 = df.with_columns(
+        pl.col("c").rank(method="random", seed=1).over("a").alias("rank")
+    )
+    df_ranks2 = df.with_columns(
+        pl.col("c").rank(method="random", seed=1).over("a").alias("rank")
+    )
+    assert_frame_equal(df_ranks1, df_ranks2)
+
+
+def test_rank_random_series() -> None:
+    s = pl.Series("a", [1, 2, 3, 2, 2, 3, 0])
+    assert_series_equal(
+        s.rank("random", seed=1), pl.Series("a", [2, 4, 7, 3, 5, 6, 1], dtype=pl.UInt32)
+    )
+
+
+def test_rank_df() -> None:
+    df = pl.DataFrame(
+        {
+            "a": [1, 1, 2, 2, 3],
+        }
+    )
+
+    s = df.select(pl.col("a").rank(method="average").alias("b")).to_series()
+    assert s.to_list() == [1.5, 1.5, 3.5, 3.5, 5.0]
+    assert s.dtype == pl.Float64
+
+    s = df.select(pl.col("a").rank(method="max").alias("b")).to_series()
+    assert s.to_list() == [2, 2, 4, 4, 5]
+    assert s.dtype == pl.get_index_type()
+
+
+def test_rank_so_4109() -> None:
+    # also tests rank's null behavior
+    df = pl.from_dict(
+        {
+            "id": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4],
+            "rank": [None, 3, 2, 4, 1, 4, 3, 2, 1, None, 3, 4, 4, 1, None, 3],
+        }
+    ).sort(by=["id", "rank"])
+
+    assert df.group_by("id").agg(
+        [
+            pl.col("rank").alias("original"),
+            pl.col("rank").rank(method="dense").alias("dense"),
+            pl.col("rank").rank(method="average").alias("average"),
+        ]
+    ).to_dict(as_series=False) == {
+        "id": [1, 2, 3, 4],
+        "original": [[None, 2, 3, 4], [1, 2, 3, 4], [None, 1, 3, 4], [None, 1, 3, 4]],
+        "dense": [[None, 1, 2, 3], [1, 2, 3, 4], [None, 1, 2, 3], [None, 1, 2, 3]],
+        "average": [
+            [None, 1.0, 2.0, 3.0],
+            [1.0, 2.0, 3.0, 4.0],
+            [None, 1.0, 2.0, 3.0],
+            [None, 1.0, 2.0, 3.0],
+        ],
+    }
+
+
+def test_rank_string_null_11252() -> None:
+    rank = pl.Series([None, "", "z", None, "a"]).rank()
+    assert rank.to_list() == [None, 1.0, 3.0, None, 2.0]
+
+
+def test_rank_series() -> None:
+    s = pl.Series("a", [1, 2, 3, 2, 2, 3, 0])
+
+    assert_series_equal(
+        s.rank("dense"), pl.Series("a", [2, 3, 4, 3, 3, 4, 1], dtype=pl.UInt32)
+    )
+
+    df = pl.DataFrame([s])
+    assert df.select(pl.col("a").rank("dense"))["a"].to_list() == [2, 3, 4, 3, 3, 4, 1]
+
+    assert_series_equal(
+        s.rank("dense", descending=True),
+        pl.Series("a", [3, 2, 1, 2, 2, 1, 4], dtype=pl.UInt32),
+    )
+
+    assert s.rank(method="average").dtype == pl.Float64
+    assert s.rank(method="max").dtype == pl.get_index_type()
diff --git a/py-polars/tests/unit/operations/test_replace.py b/py-polars/tests/unit/operations/test_replace.py
index 03d1feb2681c..81edb16a6d49 100644
--- a/py-polars/tests/unit/operations/test_replace.py
+++ b/py-polars/tests/unit/operations/test_replace.py
@@ -281,3 +281,12 @@ def test_replace_default_deprecated() -> None:
     result = s.replace(1, 10, default=None)
     expected = pl.Series([10, None, None], dtype=pl.Int32)
     assert_series_equal(result, expected)
+
+
+def test_replace_single_argument_not_mapping() -> None:
+    df =
pl.DataFrame({"a": ["a", "b", "c"]}) + with pytest.raises( + TypeError, + match="`new` argument is required if `old` argument is not a Mapping type", + ): + df.select(pl.col("a").replace("b")) diff --git a/py-polars/tests/unit/operations/test_replace_strict.py b/py-polars/tests/unit/operations/test_replace_strict.py index d72f0c7968d6..14f99585e64e 100644 --- a/py-polars/tests/unit/operations/test_replace_strict.py +++ b/py-polars/tests/unit/operations/test_replace_strict.py @@ -398,3 +398,12 @@ def test_replace_strict_cat_cat( s = pl.Series("s", ["a", "b"], dtype=dt) s_replaced = s.replace_strict(old, new, default=pl.lit("OTHER", dtype=dt)) # type: ignore[arg-type] assert_series_equal(s_replaced, expected.fill_null("OTHER")) + + +def test_replace_strict_single_argument_not_mapping() -> None: + df = pl.DataFrame({"a": ["b", "b", "b"]}) + with pytest.raises( + TypeError, + match="`new` argument is required if `old` argument is not a Mapping type", + ): + df.select(pl.col("a").replace_strict("b")) diff --git a/py-polars/tests/unit/operations/test_search_sorted.py b/py-polars/tests/unit/operations/test_search_sorted.py new file mode 100644 index 000000000000..ff1fdef1c7b9 --- /dev/null +++ b/py-polars/tests/unit/operations/test_search_sorted.py @@ -0,0 +1,78 @@ +import numpy as np +import pytest + +import polars as pl +from polars.testing import assert_series_equal + + +def test_search_sorted() -> None: + for seed in [1, 2, 3]: + np.random.seed(seed) + arr = np.sort(np.random.randn(10) * 100) + s = pl.Series(arr) + + for v in range(int(np.min(arr)), int(np.max(arr)), 20): + assert np.searchsorted(arr, v) == s.search_sorted(v) + + a = pl.Series([1, 2, 3]) + b = pl.Series([1, 2, 2, -1]) + assert a.search_sorted(b).to_list() == [0, 1, 1, 0] + b = pl.Series([1, 2, 2, None, 3]) + assert a.search_sorted(b).to_list() == [0, 1, 1, 0, 2] + + a = pl.Series(["b", "b", "d", "d"]) + b = pl.Series(["a", "b", "c", "d", "e"]) + assert a.search_sorted(b, side="left").to_list() == [0, 0, 2, 2, 4] + assert a.search_sorted(b, side="right").to_list() == [0, 2, 2, 4, 4] + + a = pl.Series([1, 1, 4, 4]) + b = pl.Series([0, 1, 2, 4, 5]) + assert a.search_sorted(b, side="left").to_list() == [0, 0, 2, 2, 4] + assert a.search_sorted(b, side="right").to_list() == [0, 2, 2, 4, 4] + + +def test_search_sorted_multichunk() -> None: + for seed in [1, 2, 3]: + np.random.seed(seed) + arr = np.sort(np.random.randn(10) * 100) + q = len(arr) // 4 + a, b, c, d = map( + pl.Series, (arr[:q], arr[q : 2 * q], arr[2 * q : 3 * q], arr[3 * q :]) + ) + s = pl.concat([a, b, c, d], rechunk=False) + assert s.n_chunks() == 4 + + for v in range(int(np.min(arr)), int(np.max(arr)), 20): + assert np.searchsorted(arr, v) == s.search_sorted(v) + + a = pl.concat( + [ + pl.Series([None, None, None], dtype=pl.Int64), + pl.Series([None, 1, 1, 2, 3]), + pl.Series([4, 4, 5, 6, 7, 8, 8]), + ], + rechunk=False, + ) + assert a.n_chunks() == 3 + b = pl.Series([-10, 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, None]) + left_ref = pl.Series( + [4, 4, 4, 6, 7, 8, 10, 11, 12, 13, 15, 0], dtype=pl.get_index_type() + ) + right_ref = pl.Series( + [4, 4, 6, 7, 8, 10, 11, 12, 13, 15, 15, 4], dtype=pl.get_index_type() + ) + assert_series_equal(a.search_sorted(b, side="left"), left_ref) + assert_series_equal(a.search_sorted(b, side="right"), right_ref) + + +def test_search_sorted_right_nulls() -> None: + a = pl.Series([1, 2, None, None]) + assert a.search_sorted(None, side="left") == 2 + assert a.search_sorted(None, side="right") == 4 + + +def 
test_raise_literal_numeric_search_sorted_18096() -> None: + df = pl.DataFrame({"foo": [1, 4, 7], "bar": [2, 3, 5]}) + + with pytest.raises(pl.exceptions.InvalidOperationError): + df.with_columns(idx=pl.col("foo").search_sorted("bar")) diff --git a/py-polars/tests/unit/operations/test_slice.py b/py-polars/tests/unit/operations/test_slice.py index 7c8fb22665c1..94dc1e3283ff 100644 --- a/py-polars/tests/unit/operations/test_slice.py +++ b/py-polars/tests/unit/operations/test_slice.py @@ -243,3 +243,48 @@ def test_double_sort_slice_pushdown_15779() -> None: assert ( pl.LazyFrame({"foo": [1, 2]}).sort("foo").head(0).sort("foo").collect() ).shape == (0, 1) + + +def test_slice_pushdown_simple_projection_18288() -> None: + lf = pl.DataFrame({"col": ["0", "notanumber"]}).lazy() + lf = lf.with_columns([pl.col("col").cast(pl.Int64)]) + lf = lf.with_columns([pl.col("col"), pl.lit(None)]) + assert lf.head(1).collect().to_dict(as_series=False) == { + "col": [0], + "literal": [None], + } + + +def test_group_by_slice_all_keys() -> None: + df = pl.DataFrame( + { + "a": ["Tom", "Nick", "Marry", "Krish", "Jack", None], + "b": [ + "2020-01-01", + "2020-01-02", + "2020-01-03", + "2020-01-04", + "2020-01-05", + None, + ], + "c": [5, 6, 6, 7, 8, 5], + } + ) + + gb = df.group_by(["a", "b", "c"], maintain_order=True) + assert_frame_equal(gb.tail(1), gb.head(1)) + + +def test_slice_first_in_agg_18551() -> None: + df = pl.DataFrame({"id": [1, 1, 2], "name": ["A", "B", "C"], "value": [31, 21, 32]}) + + assert df.group_by("id", maintain_order=True).agg( + sort_by=pl.col("name").sort_by("value"), + x=pl.col("name").sort_by("value").slice(0, 1).first(), + y=pl.col("name").sort_by("value").slice(1, 1).first(), + ).to_dict(as_series=False) == { + "id": [1, 2], + "sort_by": [["B", "A"], ["C"]], + "x": ["B", "C"], + "y": ["A", None], + } diff --git a/py-polars/tests/unit/operations/test_sort.py b/py-polars/tests/unit/operations/test_sort.py index 6119797dfe08..6d67e83a418a 100644 --- a/py-polars/tests/unit/operations/test_sort.py +++ b/py-polars/tests/unit/operations/test_sort.py @@ -507,7 +507,7 @@ def test_sort_type_coercion_6892() -> None: } -@pytest.mark.slow() +@pytest.mark.slow def test_sort_row_fmt(str_ints_df: pl.DataFrame) -> None: # we sort nulls_last as this will always dispatch # to row_fmt and is the default in pandas @@ -736,7 +736,7 @@ def test_sort_descending_nulls_last(descending: bool, nulls_last: bool) -> None: ) -@pytest.mark.release() +@pytest.mark.release def test_sort_nan_1942() -> None: # https://github.com/pola-rs/polars/issues/1942 import time diff --git a/py-polars/tests/unit/operations/test_statistics.py b/py-polars/tests/unit/operations/test_statistics.py index 8aa1b0ae6811..ed8b964582cb 100644 --- a/py-polars/tests/unit/operations/test_statistics.py +++ b/py-polars/tests/unit/operations/test_statistics.py @@ -7,6 +7,7 @@ import pytest import polars as pl +from polars import StringCache from polars.testing import assert_frame_equal @@ -37,16 +38,21 @@ def test_corr_nan() -> None: assert str(df.select(pl.corr("a", "b", ddof=1))[0, 0]) == "nan" +@StringCache() def test_hist() -> None: - a = pl.Series("a", [1, 3, 8, 8, 2, 1, 3]) - assert ( - str(a.hist(bin_count=4).to_dict(as_series=False)) - == "{'breakpoint': [0.0, 2.25, 4.5, 6.75, inf], 'category': ['(-inf, 0.0]', '(0.0, 2.25]', '(2.25, 4.5]', '(4.5, 6.75]', '(6.75, inf]'], 'count': [0, 3, 2, 0, 2]}" + s = pl.Series("a", [1, 3, 8, 8, 2, 1, 3]) + out = s.hist(bin_count=4) + expected = pl.DataFrame( + { + "breakpoint": pl.Series([2.75, 4.5, 6.25, 
8.0], dtype=pl.Float64), + "category": pl.Series( + ["(0.993, 2.75]", "(2.75, 4.5]", "(4.5, 6.25]", "(6.25, 8.0]"], + dtype=pl.Categorical, + ), + "count": pl.Series([3, 2, 0, 2], dtype=pl.get_index_type()), + } ) - - assert a.hist( - bins=[0, 2], include_category=False, include_breakpoint=False - ).to_series().to_list() == [0, 3, 4] + assert_frame_equal(out, expected, categorical_as_str=True) @pytest.mark.parametrize("values", [[], [None]]) diff --git a/py-polars/tests/unit/operations/test_unpivot.py b/py-polars/tests/unit/operations/test_unpivot.py index a4155da56874..7b51d91122dc 100644 --- a/py-polars/tests/unit/operations/test_unpivot.py +++ b/py-polars/tests/unit/operations/test_unpivot.py @@ -92,3 +92,9 @@ def test_unpivot_raise_list() -> None: pl.LazyFrame( {"a": ["x", "y"], "b": [["test", "test2"], ["test3", "test4"]]} ).unpivot().collect() + + +def test_unpivot_empty_18170() -> None: + assert pl.DataFrame().unpivot().schema == pl.Schema( + {"variable": pl.String(), "value": pl.Null()} + ) diff --git a/py-polars/tests/unit/operations/test_window.py b/py-polars/tests/unit/operations/test_window.py index 17df33d268dc..ddf293d216c9 100644 --- a/py-polars/tests/unit/operations/test_window.py +++ b/py-polars/tests/unit/operations/test_window.py @@ -454,7 +454,7 @@ def test_window_agg_list_null_15437() -> None: assert_frame_equal(output, expected) -@pytest.mark.release() +@pytest.mark.release def test_windows_not_cached() -> None: ldf = ( pl.DataFrame( diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 5f3ea2c91b09..3f7f159ccae0 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -994,25 +994,6 @@ def test_mode() -> None: assert pl.int_range(0, 3, eager=True).mode().to_list() == [2, 1, 0] -def test_rank() -> None: - s = pl.Series("a", [1, 2, 3, 2, 2, 3, 0]) - - assert_series_equal( - s.rank("dense"), pl.Series("a", [2, 3, 4, 3, 3, 4, 1], dtype=UInt32) - ) - - df = pl.DataFrame([s]) - assert df.select(pl.col("a").rank("dense"))["a"].to_list() == [2, 3, 4, 3, 3, 4, 1] - - assert_series_equal( - s.rank("dense", descending=True), - pl.Series("a", [3, 2, 1, 2, 2, 1, 4], dtype=UInt32), - ) - - assert s.rank(method="average").dtype == pl.Float64 - assert s.rank(method="max").dtype == pl.get_index_type() - - def test_diff() -> None: s = pl.Series("a", [1, 2, 3, 2, 2, 3, 0]) expected = pl.Series("a", [1, 1, -1, 0, 1, -3]) @@ -1766,8 +1747,8 @@ def test_sign() -> None: assert_series_equal(a.sign(), expected) # Floats - a = pl.Series("a", [-9.0, -0.0, 0.0, 4.0, None]) - expected = pl.Series("a", [-1, 0, 0, 1, None]) + a = pl.Series("a", [-9.0, -0.0, 0.0, 4.0, float("nan"), None]) + expected = pl.Series("a", [-1.0, 0.0, 0.0, 1.0, float("nan"), None]) assert_series_equal(a.sign(), expected) # Invalid input @@ -2165,3 +2146,8 @@ def test_series_from_numpy_with_dtype() -> None: s = pl.Series("foo", np.array([-1, 2, 3]), dtype=pl.UInt8, strict=False) assert s.to_list() == [None, 2, 3] assert s.dtype == pl.UInt8 + + +def test_raise_invalid_is_between() -> None: + with pytest.raises(pl.exceptions.InvalidOperationError): + pl.select(pl.lit(2).is_between(pl.lit("11"), pl.lit("33"))) diff --git a/py-polars/tests/unit/sql/test_conditional.py b/py-polars/tests/unit/sql/test_conditional.py index 21c9d0dc738b..b2000ebe37b1 100644 --- a/py-polars/tests/unit/sql/test_conditional.py +++ b/py-polars/tests/unit/sql/test_conditional.py @@ -10,7 +10,7 @@ from polars.testing import assert_frame_equal 
-@pytest.fixture() +@pytest.fixture def foods_ipc_path() -> Path: return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc" diff --git a/py-polars/tests/unit/sql/test_functions.py b/py-polars/tests/unit/sql/test_functions.py index 04055eb01c03..84f2ecd972bc 100644 --- a/py-polars/tests/unit/sql/test_functions.py +++ b/py-polars/tests/unit/sql/test_functions.py @@ -9,7 +9,7 @@ from polars.testing import assert_frame_equal -@pytest.fixture() +@pytest.fixture def foods_ipc_path() -> Path: return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc" diff --git a/py-polars/tests/unit/sql/test_group_by.py b/py-polars/tests/unit/sql/test_group_by.py index 83e91e78ed5d..08e4b236c833 100644 --- a/py-polars/tests/unit/sql/test_group_by.py +++ b/py-polars/tests/unit/sql/test_group_by.py @@ -10,7 +10,7 @@ from polars.testing import assert_frame_equal -@pytest.fixture() +@pytest.fixture def foods_ipc_path() -> Path: return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc" diff --git a/py-polars/tests/unit/sql/test_joins.py b/py-polars/tests/unit/sql/test_joins.py index 97872d8bbdcc..3a1e90bed1b4 100644 --- a/py-polars/tests/unit/sql/test_joins.py +++ b/py-polars/tests/unit/sql/test_joins.py @@ -10,7 +10,7 @@ from polars.testing import assert_frame_equal -@pytest.fixture() +@pytest.fixture def foods_ipc_path() -> Path: return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc" diff --git a/py-polars/tests/unit/sql/test_miscellaneous.py b/py-polars/tests/unit/sql/test_miscellaneous.py index 1ecd08e01e25..95ba8461bebe 100644 --- a/py-polars/tests/unit/sql/test_miscellaneous.py +++ b/py-polars/tests/unit/sql/test_miscellaneous.py @@ -2,6 +2,7 @@ from datetime import date from pathlib import Path +from typing import TYPE_CHECKING, Any import pytest @@ -9,8 +10,11 @@ from polars.exceptions import SQLInterfaceError, SQLSyntaxError from polars.testing import assert_frame_equal +if TYPE_CHECKING: + from polars.datatypes import DataType -@pytest.fixture() + +@pytest.fixture def foods_ipc_path() -> Path: return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc" @@ -53,6 +57,28 @@ def test_any_all() -> None: } +@pytest.mark.parametrize( + ("data", "schema"), + [ + ({"x": [1, 2, 3, 4]}, None), + ({"x": [9, 8, 7, 6]}, {"x": pl.Int8}), + ({"x": ["aa", "bb"]}, {"x": pl.Struct}), + ({"x": [None, None], "y": [None, None]}, {"x": pl.Date, "y": pl.Float64}), + ], +) +def test_boolean_where_clauses( + data: dict[str, Any], schema: dict[str, DataType] | None +) -> None: + df = pl.DataFrame(data=data, schema=schema) + empty_df = df.clear() + + for true in ("TRUE", "1=1", "2 == 2", "'xx' = 'xx'", "TRUE AND 1=1"): + assert_frame_equal(df, df.sql(f"SELECT * FROM self WHERE {true}")) + + for false in ("false", "1!=1", "2 != 2", "'xx' != 'xx'", "FALSE OR 1!=1"): + assert_frame_equal(empty_df, df.sql(f"SELECT * FROM self WHERE {false}")) + + def test_count() -> None: df = pl.DataFrame( { diff --git a/py-polars/tests/unit/sql/test_operators.py b/py-polars/tests/unit/sql/test_operators.py index 668ead0bc087..e42825259b34 100644 --- a/py-polars/tests/unit/sql/test_operators.py +++ b/py-polars/tests/unit/sql/test_operators.py @@ -9,7 +9,7 @@ from polars.testing import assert_frame_equal -@pytest.fixture() +@pytest.fixture def foods_ipc_path() -> Path: return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc" @@ -37,7 +37,7 @@ def test_div() -> None: [ [-0.0995024875621891, 2.85714285714286, 12.0, None, -15.92356687898089], [-1, 2, 12, None, -16], - [-1, 1, 1, None, -1], + [-1.0, 
1.0, 1.0, None, -1.0], ], schema=["a_div_b", "a_floordiv_b", "b_sign"], ), diff --git a/py-polars/tests/unit/sql/test_order_by.py b/py-polars/tests/unit/sql/test_order_by.py index bd93743a416a..704ebd3e773e 100644 --- a/py-polars/tests/unit/sql/test_order_by.py +++ b/py-polars/tests/unit/sql/test_order_by.py @@ -8,7 +8,7 @@ from polars.exceptions import SQLInterfaceError, SQLSyntaxError -@pytest.fixture() +@pytest.fixture def foods_ipc_path() -> Path: return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc" diff --git a/py-polars/tests/unit/sql/test_regex.py b/py-polars/tests/unit/sql/test_regex.py index ece4e02bd8fe..af0b9a31fa67 100644 --- a/py-polars/tests/unit/sql/test_regex.py +++ b/py-polars/tests/unit/sql/test_regex.py @@ -8,7 +8,7 @@ from polars.exceptions import SQLSyntaxError -@pytest.fixture() +@pytest.fixture def foods_ipc_path() -> Path: return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc" diff --git a/py-polars/tests/unit/sql/test_set_ops.py b/py-polars/tests/unit/sql/test_set_ops.py index 64508887d1c5..f148d561c31b 100644 --- a/py-polars/tests/unit/sql/test_set_ops.py +++ b/py-polars/tests/unit/sql/test_set_ops.py @@ -69,6 +69,26 @@ def test_except_intersect_by_name() -> None: assert res_i.columns == ["x", "y", "z"] +@pytest.mark.parametrize( + ("op", "op_subtype"), + [ + ("EXCEPT", "ALL"), + ("EXCEPT", "ALL BY NAME"), + ("INTERSECT", "ALL"), + ("INTERSECT", "ALL BY NAME"), + ], +) +def test_except_intersect_all_unsupported(op: str, op_subtype: str) -> None: + df1 = pl.DataFrame({"n": [1, 1, 1, 2, 2, 2, 3]}) # noqa: F841 + df2 = pl.DataFrame({"n": [1, 1, 2, 2]}) # noqa: F841 + + with pytest.raises( + SQLInterfaceError, + match=f"'{op} {op_subtype}' is not supported", + ): + pl.sql(f"SELECT * FROM df1 {op} {op_subtype} SELECT * FROM df2") + + @pytest.mark.parametrize("op", ["EXCEPT", "INTERSECT", "UNION"]) def test_except_intersect_errors(op: str) -> None: df1 = pl.DataFrame({"x": [1, 9, 1, 1], "y": [2, 3, 4, 4], "z": [5, 5, 5, 5]}) # noqa: F841 diff --git a/py-polars/tests/unit/sql/test_strings.py b/py-polars/tests/unit/sql/test_strings.py index c0c4883f0dda..46e3c645a85a 100644 --- a/py-polars/tests/unit/sql/test_strings.py +++ b/py-polars/tests/unit/sql/test_strings.py @@ -10,7 +10,7 @@ # TODO: Do not rely on I/O for these tests -@pytest.fixture() +@pytest.fixture def foods_ipc_path() -> Path: return Path(__file__).parent.parent / "io" / "files" / "foods1.ipc" diff --git a/py-polars/tests/unit/sql/test_structs.py b/py-polars/tests/unit/sql/test_structs.py index fdd404023713..ce324172502b 100644 --- a/py-polars/tests/unit/sql/test_structs.py +++ b/py-polars/tests/unit/sql/test_structs.py @@ -7,7 +7,7 @@ from polars.testing import assert_frame_equal -@pytest.fixture() +@pytest.fixture def df_struct() -> pl.DataFrame: return pl.DataFrame( { diff --git a/py-polars/tests/unit/sql/test_table_operations.py b/py-polars/tests/unit/sql/test_table_operations.py index 73b544c9c0ed..8c3862b85b35 100644 --- a/py-polars/tests/unit/sql/test_table_operations.py +++ b/py-polars/tests/unit/sql/test_table_operations.py @@ -10,7 +10,7 @@ from polars.testing import assert_frame_equal -@pytest.fixture() +@pytest.fixture def test_frame() -> pl.LazyFrame: return pl.LazyFrame( { diff --git a/py-polars/tests/unit/sql/test_wildcard_opts.py b/py-polars/tests/unit/sql/test_wildcard_opts.py index d1fd6873bc3c..e27ce9ac14b3 100644 --- a/py-polars/tests/unit/sql/test_wildcard_opts.py +++ b/py-polars/tests/unit/sql/test_wildcard_opts.py @@ -9,7 +9,7 @@ from polars.testing import 
assert_frame_equal -@pytest.fixture() +@pytest.fixture def df() -> pl.DataFrame: return pl.DataFrame( { diff --git a/py-polars/tests/unit/streaming/conftest.py b/py-polars/tests/unit/streaming/conftest.py index b7b476474316..d3c434edcc3e 100644 --- a/py-polars/tests/unit/streaming/conftest.py +++ b/py-polars/tests/unit/streaming/conftest.py @@ -3,6 +3,6 @@ import pytest -@pytest.fixture() +@pytest.fixture def io_files_path() -> Path: return Path(__file__).parent.parent / "io" / "files" diff --git a/py-polars/tests/unit/streaming/test_streaming.py b/py-polars/tests/unit/streaming/test_streaming.py index 9c8ae9a2d57b..80730273672e 100644 --- a/py-polars/tests/unit/streaming/test_streaming.py +++ b/py-polars/tests/unit/streaming/test_streaming.py @@ -53,6 +53,7 @@ def test_streaming_block_on_literals_6054() -> None: ).sort("col_1").to_dict(as_series=False) == {"col_1": [0, 1], "col_2": [0, 5]} +@pytest.mark.may_fail_auto_streaming def test_streaming_streamable_functions(monkeypatch: Any, capfd: Any) -> None: monkeypatch.setenv("POLARS_VERBOSE", "1") assert ( @@ -72,7 +73,7 @@ def test_streaming_streamable_functions(monkeypatch: Any, capfd: Any) -> None: assert "df -> function -> ordered_sink" in err -@pytest.mark.slow() +@pytest.mark.slow def test_cross_join_stack() -> None: a = pl.Series(np.arange(100_000)).to_frame().lazy() t0 = time.time() @@ -115,6 +116,7 @@ def test_streaming_literal_expansion() -> None: } +@pytest.mark.may_fail_auto_streaming def test_streaming_apply(monkeypatch: Any, capfd: Any) -> None: monkeypatch.setenv("POLARS_VERBOSE", "1") @@ -161,8 +163,8 @@ def test_streaming_sortedness_propagation_9494() -> None: } -@pytest.mark.write_disk() -@pytest.mark.slow() +@pytest.mark.write_disk +@pytest.mark.slow def test_streaming_generic_left_and_inner_join_from_disk(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) p0 = tmp_path / "df0.parquet" @@ -212,7 +214,7 @@ def test_streaming_9776() -> None: assert unordered.sort(["col_1", "ID"]).rows() == expected -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_stream_empty_file(tmp_path: Path) -> None: p = tmp_path / "in.parquet" schema = { @@ -288,7 +290,7 @@ def test_boolean_agg_schema() -> None: ) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_streaming_csv_headers_but_no_data_13770(tmp_path: Path) -> None: with Path.open(tmp_path / "header_no_data.csv", "w") as f: f.write("name, age\n") @@ -303,7 +305,7 @@ def test_streaming_csv_headers_but_no_data_13770(tmp_path: Path) -> None: assert df.schema == schema -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_custom_temp_dir(tmp_path: Path, monkeypatch: Any) -> None: tmp_path.mkdir(exist_ok=True) monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path)) @@ -317,7 +319,7 @@ def test_custom_temp_dir(tmp_path: Path, monkeypatch: Any) -> None: assert os.listdir(tmp_path), f"Temp directory '{tmp_path}' is empty" -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_streaming_with_hconcat(tmp_path: Path) -> None: df1 = pl.DataFrame( { diff --git a/py-polars/tests/unit/streaming/test_streaming_group_by.py b/py-polars/tests/unit/streaming/test_streaming_group_by.py index de36443fc3a5..561783f833a8 100644 --- a/py-polars/tests/unit/streaming/test_streaming_group_by.py +++ b/py-polars/tests/unit/streaming/test_streaming_group_by.py @@ -17,7 +17,7 @@ pytestmark = pytest.mark.xdist_group("streaming") -@pytest.mark.slow() +@pytest.mark.slow def test_streaming_group_by_sorted_fast_path_nulls_10273() -> None: df = pl.Series( name="x", @@ -207,7 +207,7 @@ def 
random_integers() -> pl.Series: return pl.Series("a", np.random.randint(0, 10, 100), dtype=pl.Int64) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_streaming_group_by_ooc_q1( random_integers: pl.Series, tmp_path: Path, @@ -235,7 +235,7 @@ def test_streaming_group_by_ooc_q1( assert_frame_equal(result, expected) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_streaming_group_by_ooc_q2( random_integers: pl.Series, tmp_path: Path, @@ -263,7 +263,7 @@ def test_streaming_group_by_ooc_q2( assert_frame_equal(result, expected) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_streaming_group_by_ooc_q3( random_integers: pl.Series, tmp_path: Path, @@ -306,7 +306,7 @@ def test_streaming_group_by_struct_key() -> None: } -@pytest.mark.slow() +@pytest.mark.slow def test_streaming_group_by_all_numeric_types_stability_8570() -> None: m = 1000 n = 1000 diff --git a/py-polars/tests/unit/streaming/test_streaming_io.py b/py-polars/tests/unit/streaming/test_streaming_io.py index 3c1ec7fb2469..0cbf0d90e4ba 100644 --- a/py-polars/tests/unit/streaming/test_streaming_io.py +++ b/py-polars/tests/unit/streaming/test_streaming_io.py @@ -1,5 +1,6 @@ from __future__ import annotations +import io from typing import TYPE_CHECKING, Any from unittest.mock import patch @@ -14,7 +15,7 @@ pytestmark = pytest.mark.xdist_group("streaming") -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_streaming_parquet_glob_5900(df: pl.DataFrame, tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) file_path = tmp_path / "small.parquet" @@ -47,7 +48,7 @@ def test_scan_csv_overwrite_small_dtypes( assert df.dtypes == [pl.String, pl.Int64, pl.Float64, dtype] -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_sink_parquet(io_files_path: Path, tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -64,7 +65,7 @@ def test_sink_parquet(io_files_path: Path, tmp_path: Path) -> None: assert_frame_equal(result, df_read) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_sink_parquet_10115(tmp_path: Path) -> None: in_path = tmp_path / "in.parquet" out_path = tmp_path / "out.parquet" @@ -89,7 +90,7 @@ def test_sink_parquet_10115(tmp_path: Path) -> None: } -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_sink_ipc(io_files_path: Path, tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -106,7 +107,7 @@ def test_sink_ipc(io_files_path: Path, tmp_path: Path) -> None: assert_frame_equal(result, df_read) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_sink_csv(io_files_path: Path, tmp_path: Path) -> None: source_file = io_files_path / "small.parquet" target_file = tmp_path / "sink.csv" @@ -119,7 +120,7 @@ def test_sink_csv(io_files_path: Path, tmp_path: Path) -> None: assert_frame_equal(target_data, source_data) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_sink_csv_14494(tmp_path: Path) -> None: pl.LazyFrame({"c": [1, 2, 3]}, schema={"c": pl.Int64}).filter( pl.col("c") > 10 @@ -194,7 +195,7 @@ def test_sink_csv_batch_size_zero() -> None: lf.sink_csv("test.csv", batch_size=0) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_sink_csv_nested_data(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) path = tmp_path / "data.csv" @@ -218,7 +219,7 @@ def test_scan_empty_csv_10818(io_files_path: Path) -> None: assert df.is_empty() -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_streaming_cross_join_schema(tmp_path: Path) -> None: file_path = tmp_path / "temp.parquet" a = pl.DataFrame({"a": [1, 2]}).lazy() @@ -228,7 +229,7 
@@ def test_streaming_cross_join_schema(tmp_path: Path) -> None: assert read.to_dict(as_series=False) == {"a": [1, 2], "b": ["b", "b"]} -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_sink_ndjson_should_write_same_data( io_files_path: Path, tmp_path: Path ) -> None: @@ -246,7 +247,7 @@ def test_sink_ndjson_should_write_same_data( assert_frame_equal(df, expected) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_parquet_eq_statistics(monkeypatch: Any, capfd: Any, tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -286,7 +287,7 @@ def test_parquet_eq_statistics(monkeypatch: Any, capfd: Any, tmp_path: Path) -> ) -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_streaming_empty_parquet_16523(tmp_path: Path) -> None: file_path = tmp_path / "foo.parquet" df = pl.DataFrame({"a": []}, schema={"a": pl.Int32}) @@ -294,3 +295,26 @@ def test_streaming_empty_parquet_16523(tmp_path: Path) -> None: q = pl.scan_parquet(file_path) q2 = pl.LazyFrame({"a": [1]}, schema={"a": pl.Int32}) assert q.join(q2, on="a").collect(streaming=True).shape == (0, 1) + + +@pytest.mark.parametrize( + "method", + ["parquet", "csv"], +) +def test_nyi_scan_in_memory(method: str) -> None: + f = io.BytesIO() + df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": ["x", "y", "z"], + } + ) + + (getattr(df, f"write_{method}"))(f) + + f.seek(0) + with pytest.raises( + pl.exceptions.ComputeError, + match="not yet implemented: Streaming scanning of in-memory buffers", + ): + (getattr(pl, f"scan_{method}"))(f).collect(streaming=True) diff --git a/py-polars/tests/unit/streaming/test_streaming_join.py b/py-polars/tests/unit/streaming/test_streaming_join.py index 5edc7456f7d6..074110076b28 100644 --- a/py-polars/tests/unit/streaming/test_streaming_join.py +++ b/py-polars/tests/unit/streaming/test_streaming_join.py @@ -269,7 +269,7 @@ def test_non_coalescing_streaming_left_join() -> None: } -@pytest.mark.write_disk() +@pytest.mark.write_disk def test_streaming_outer_join_partial_flush(tmp_path: Path) -> None: data = { "value_at": [datetime(2024, i + 1, 1) for i in range(6)], diff --git a/py-polars/tests/unit/streaming/test_streaming_sort.py b/py-polars/tests/unit/streaming/test_streaming_sort.py index 4973b51fa623..8388671bd717 100644 --- a/py-polars/tests/unit/streaming/test_streaming_sort.py +++ b/py-polars/tests/unit/streaming/test_streaming_sort.py @@ -73,8 +73,8 @@ def test_streaming_sort_multiple_columns_logical_types() -> None: assert_frame_equal(result, expected) -@pytest.mark.write_disk() -@pytest.mark.slow() +@pytest.mark.write_disk +@pytest.mark.slow def test_ooc_sort(tmp_path: Path, monkeypatch: Any) -> None: tmp_path.mkdir(exist_ok=True) monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path)) @@ -92,8 +92,8 @@ def test_ooc_sort(tmp_path: Path, monkeypatch: Any) -> None: assert_series_equal(out, s.sort(descending=descending)) -@pytest.mark.debug() -@pytest.mark.write_disk() +@pytest.mark.debug +@pytest.mark.write_disk @pytest.mark.parametrize("spill_source", [True, False]) def test_streaming_sort( tmp_path: Path, monkeypatch: Any, capfd: Any, spill_source: bool @@ -119,7 +119,7 @@ def test_streaming_sort( assert "PARTITIONED FORCE SPILLED" in err -@pytest.mark.write_disk() +@pytest.mark.write_disk @pytest.mark.parametrize("spill_source", [True, False]) def test_out_of_core_sort_9503( tmp_path: Path, monkeypatch: Any, spill_source: bool @@ -183,8 +183,8 @@ def test_out_of_core_sort_9503( } -@pytest.mark.write_disk() -@pytest.mark.slow() +@pytest.mark.write_disk +@pytest.mark.slow def 
test_streaming_sort_multiple_columns(
     str_ints_df: pl.DataFrame, tmp_path: Path, monkeypatch: Any, capfd: Any
 ) -> None:
diff --git a/py-polars/tests/unit/streaming/test_streaming_unique.py b/py-polars/tests/unit/streaming/test_streaming_unique.py
index 77d7534548dd..e477952a8c0b 100644
--- a/py-polars/tests/unit/streaming/test_streaming_unique.py
+++ b/py-polars/tests/unit/streaming/test_streaming_unique.py
@@ -13,8 +13,8 @@
 pytestmark = pytest.mark.xdist_group("streaming")
 
 
-@pytest.mark.write_disk()
-@pytest.mark.slow()
+@pytest.mark.write_disk
+@pytest.mark.slow
 def test_streaming_out_of_core_unique(
     io_files_path: Path, tmp_path: Path, monkeypatch: Any, capfd: Any
 ) -> None:
diff --git a/py-polars/tests/unit/test_api.py b/py-polars/tests/unit/test_api.py
index 57bd08a6cc87..a90bc5ceb68d 100644
--- a/py-polars/tests/unit/test_api.py
+++ b/py-polars/tests/unit/test_api.py
@@ -124,7 +124,7 @@ def square(self) -> pl.Series:
     ]
 
 
-@pytest.mark.slow()
+@pytest.mark.slow
 @pytest.mark.parametrize("pcls", [pl.Expr, pl.DataFrame, pl.LazyFrame, pl.Series])
 def test_class_namespaces_are_registered(pcls: Any) -> None:
     # confirm that existing (and new) namespaces
diff --git a/py-polars/tests/unit/test_config.py b/py-polars/tests/unit/test_config.py
index d193ed6435d9..e4f1a152ad7b 100644
--- a/py-polars/tests/unit/test_config.py
+++ b/py-polars/tests/unit/test_config.py
@@ -610,7 +610,7 @@ def test_numeric_right_alignment() -> None:
     )
 
 
-@pytest.mark.write_disk()
+@pytest.mark.write_disk
 def test_config_load_save(tmp_path: Path) -> None:
     for file in (
         None,
@@ -688,6 +688,24 @@ def test_config_load_save(tmp_path: Path) -> None:
     assert isinstance(repr(pl.DataFrame({"xyz": [0]})), str)
 
 
+def test_config_load_save_context() -> None:
+    # store the default configuration state
+    default_state = pl.Config.save()
+
+    # establish some non-default settings
+    pl.Config.set_tbl_formatting("ASCII_MARKDOWN")
+    pl.Config.set_verbose(True)
+
+    # load the default config, validate load & context manager behaviour
+    with pl.Config.load(default_state):
+        assert os.environ.get("POLARS_FMT_TABLE_FORMATTING") is None
+        assert os.environ.get("POLARS_VERBOSE") is None
+
+    # ensure earlier state was restored
+    assert os.environ["POLARS_FMT_TABLE_FORMATTING"] == "ASCII_MARKDOWN"
+    assert os.environ["POLARS_VERBOSE"]
+
+
 def test_config_scope() -> None:
     pl.Config.set_verbose(False)
     pl.Config.set_tbl_cols(8)
diff --git a/py-polars/tests/unit/test_cpu_check.py b/py-polars/tests/unit/test_cpu_check.py
index 23525f5126dd..fdfa5965f6ff 100644
--- a/py-polars/tests/unit/test_cpu_check.py
+++ b/py-polars/tests/unit/test_cpu_check.py
@@ -6,7 +6,7 @@
 from polars._cpu_check import check_cpu_flags
 
 
-@pytest.fixture()
+@pytest.fixture
 def _feature_flags(monkeypatch: pytest.MonkeyPatch) -> None:
     """Use the default set of feature flags."""
     feature_flags = "+sse3,+ssse3"
diff --git a/py-polars/tests/unit/test_cse.py b/py-polars/tests/unit/test_cse.py
index 7a46e73e8c36..c4b6466e51b7 100644
--- a/py-polars/tests/unit/test_cse.py
+++ b/py-polars/tests/unit/test_cse.py
@@ -145,7 +145,7 @@ def test_cse_9630() -> None:
     assert_frame_equal(result, expected)
 
 
-@pytest.mark.write_disk()
+@pytest.mark.write_disk
 def test_schema_row_index_cse() -> None:
     csv_a = NamedTemporaryFile()
     csv_a.write(
@@ -181,7 +181,7 @@ def test_schema_row_index_cse() -> None:
     assert_frame_equal(result, expected)
 
 
-@pytest.mark.debug()
+@pytest.mark.debug
 def test_cse_expr_selection_context() -> None:
     q = pl.LazyFrame(
         {
@@ -634,7 +634,7 @@ def test_cse_15548() -> None:
     assert len(ldf3.collect(comm_subplan_elim=True)) == 4
 
 
-@pytest.mark.debug()
+@pytest.mark.debug
 def test_cse_and_schema_update_projection_pd() -> None:
     df = pl.LazyFrame({"a": [1, 2], "b": [99, 99]})
 
@@ -654,7 +654,7 @@ def test_cse_and_schema_update_projection_pd() -> None:
     assert num_cse_occurrences(q.explain(comm_subexpr_elim=True)) == 1
 
 
-@pytest.mark.debug()
+@pytest.mark.debug
 def test_cse_predicate_self_join(capfd: Any, monkeypatch: Any) -> None:
     monkeypatch.setenv("POLARS_VERBOSE", "1")
     y = pl.LazyFrame({"a": [1], "b": [2], "y": [3]})
@@ -702,7 +702,7 @@ def test_cse_no_projection_15980() -> None:
     ) == {"x": ["a", "a"]}
 
 
-@pytest.mark.debug()
+@pytest.mark.debug
 def test_cse_series_collision_16138() -> None:
     holdings = pl.DataFrame(
         {
@@ -765,3 +765,33 @@ def test_cse_non_scalar_length_mismatch_17732() -> None:
     )
 
     assert_frame_equal(expect, got)
+
+
+def test_cse_chunks_18124() -> None:
+    df = pl.DataFrame(
+        {
+            "ts_diff": [timedelta(seconds=60)] * 2,
+            "ts_diff_after": [timedelta(seconds=120)] * 2,
+        }
+    )
+    df = pl.concat([df, df], rechunk=False)
+    assert (
+        df.lazy()
+        .with_columns(
+            ts_diff_sign=pl.col("ts_diff") > pl.duration(seconds=0),
+            ts_diff_after_sign=pl.col("ts_diff_after") > pl.duration(seconds=0),
+        )
+        .filter(pl.col("ts_diff") > 1)
+    ).collect().shape == (4, 4)
+
+
+def test_eager_cse_during_struct_expansion_18411() -> None:
+    df = pl.DataFrame({"foo": [0, 0, 0, 1, 1]})
+    vc = pl.col("foo").value_counts()
+    classes = vc.struct[0]
+    counts = vc.struct[1]
+    # Check if output is stable
+    assert (
+        df.select(pl.col("foo").replace(classes, counts))
+        == df.select(pl.col("foo").replace(classes, counts))
+    )["foo"].all()
diff --git a/py-polars/tests/unit/test_datatypes.py b/py-polars/tests/unit/test_datatypes.py
index 8313c4203c1f..9bd545125f64 100644
--- a/py-polars/tests/unit/test_datatypes.py
+++ b/py-polars/tests/unit/test_datatypes.py
@@ -196,3 +196,8 @@ def test_struct_field_iter() -> None:
         ("b", List(Int64)),
         ("a", List(List(Int64))),
     ]
+
+
+def test_raise_invalid_namespace() -> None:
+    with pytest.raises(pl.exceptions.InvalidOperationError):
+        pl.select(pl.lit(1.5).str.replace("1", "2"))
diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py
index 8d632bbd5a36..07b98d9d8111 100644
--- a/py-polars/tests/unit/test_errors.py
+++ b/py-polars/tests/unit/test_errors.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import io
-from datetime import date, datetime, time
+from datetime import date, datetime, time, tzinfo
 from decimal import Decimal
 from typing import TYPE_CHECKING, Any
@@ -326,10 +326,16 @@ def test_datetime_time_add_err() -> None:
 def test_invalid_dtype() -> None:
     with pytest.raises(
         TypeError,
-        match="cannot parse input of type 'str' into Polars data type: 'mayonnaise'",
+        match=r"cannot parse input of type 'str' into Polars data type \(given: 'mayonnaise'\)",
     ):
         pl.Series([1, 2], dtype="mayonnaise")  # type: ignore[arg-type]
 
+    with pytest.raises(
+        TypeError,
+        match="cannot parse input into Polars data type",
+    ):
+        pl.Series([None], dtype=tzinfo)  # type: ignore[arg-type]
+
 
 def test_arr_eval_named_cols() -> None:
     df = pl.DataFrame({"A": ["a", "b"], "B": [["a", "b"], ["c", "d"]]})
@@ -343,7 +349,7 @@
 def test_alias_in_join_keys() -> None:
     df = pl.DataFrame({"A": ["a", "b"], "B": [["a", "b"], ["c", "d"]]})
     with pytest.raises(
-        ComputeError,
+        InvalidOperationError,
         match=r"'alias' is not allowed in a join key, use 'with_columns' first",
     ):
         df.join(df, on=pl.col("A").alias("foo"))
@@ -484,7 +490,7 @@ def test_skip_nulls_err() -> None:
     with pytest.raises(
         ComputeError,
-        match=r"The output type of the 'apply' function cannot be determined",
+        match=r"The output type of the 'map_elements' function cannot be determined",
     ):
         df.with_columns(pl.col("foo").map_elements(lambda x: x, skip_nulls=True))
@@ -587,7 +593,7 @@ def test_invalid_is_in_dtypes(
     if expected is None:
         with pytest.raises(
             InvalidOperationError,
-            match="`is_in` cannot check for .*? values in .*? data",
+            match="'is_in' cannot check for .*? values in .*? data",
         ):
             df.select(pl.col(colname).is_in(values))
     else:
@@ -650,7 +656,7 @@ def test_invalid_product_type() -> None:
 def test_fill_null_invalid_supertype() -> None:
     df = pl.DataFrame({"date": [date(2022, 1, 1), None]})
-    with pytest.raises(InvalidOperationError, match="could not determine supertype of"):
+    with pytest.raises(InvalidOperationError, match="got invalid or ambiguous"):
         df.select(pl.col("date").fill_null(1.0))
diff --git a/py-polars/tests/unit/test_plugins.py b/py-polars/tests/unit/test_plugins.py
index 828ae8047306..60d14852a0d5 100644
--- a/py-polars/tests/unit/test_plugins.py
+++ b/py-polars/tests/unit/test_plugins.py
@@ -15,7 +15,7 @@
 )
 
 
-@pytest.mark.write_disk()
+@pytest.mark.write_disk
 def test_register_plugin_function_invalid_plugin_path(tmp_path: Path) -> None:
     tmp_path.mkdir(exist_ok=True)
     plugin_path = tmp_path / "lib.so"
@@ -44,7 +44,7 @@ def test_serialize_kwargs(input: dict[str, Any] | None, expected: bytes) -> None
     assert _serialize_kwargs(input) == expected
 
 
-@pytest.mark.write_disk()
+@pytest.mark.write_disk
 def test_resolve_plugin_path(tmp_path: Path) -> None:
     tmp_path.mkdir(exist_ok=True)
@@ -63,7 +63,7 @@ def test_resolve_plugin_path(tmp_path: Path) -> None:
     assert result == expected
 
 
-@pytest.mark.write_disk()
+@pytest.mark.write_disk
 def test_resolve_plugin_path_raises(tmp_path: Path) -> None:
     tmp_path.mkdir(exist_ok=True)
     (tmp_path / "__init__.py").touch()
@@ -72,7 +72,7 @@ def test_resolve_plugin_path_raises(tmp_path: Path) -> None:
         _resolve_plugin_path(tmp_path)
 
 
-@pytest.mark.write_disk()
+@pytest.mark.write_disk
 @pytest.mark.parametrize(
     ("path", "expected"),
     [
@@ -89,7 +89,7 @@ def test_is_dynamic_lib(path: Path, expected: bool, tmp_path: Path) -> None:
     assert _is_dynamic_lib(full_path) is expected
 
 
-@pytest.mark.write_disk()
+@pytest.mark.write_disk
 def test_is_dynamic_lib_dir(tmp_path: Path) -> None:
     path = Path("lib.so")
     full_path = tmp_path / path
diff --git a/py-polars/tests/unit/test_polars_import.py b/py-polars/tests/unit/test_polars_import.py
index ed9bad5a4a7a..fa1779de3478 100644
--- a/py-polars/tests/unit/test_polars_import.py
+++ b/py-polars/tests/unit/test_polars_import.py
@@ -67,7 +67,7 @@ def _import_timings_as_frame(n_tries: int) -> tuple[pl.DataFrame, int]:
 
 
 @pytest.mark.skipif(sys.platform == "win32", reason="Unreliable on Windows")
-@pytest.mark.slow()
+@pytest.mark.slow
 def test_polars_import() -> None:
     # up-front compile '.py' -> '.pyc' before timing
     polars_path = Path(pl.__file__).parent
diff --git a/py-polars/tests/unit/test_projections.py b/py-polars/tests/unit/test_projections.py
index 9586bfb0a2ae..700100ced4c4 100644
--- a/py-polars/tests/unit/test_projections.py
+++ b/py-polars/tests/unit/test_projections.py
@@ -78,11 +78,19 @@ def test_unnest_projection_pushdown() -> None:
         pl.col("field_2").cast(pl.Categorical).alias("col"),
         pl.col("value"),
     )
-    out = mlf.collect().to_dict(as_series=False)
+
+    out = (
+        mlf.sort(
+            [pl.col.row.cast(pl.String), pl.col.col.cast(pl.String)],
+            maintain_order=True,
+        )
+        .collect()
+        .to_dict(as_series=False)
+    )
     assert out == {
-        "row": ["y", "y", "b", "b"],
-        "col": ["z", "z", "c", "c"],
-        "value": [1, 2, 2, 3],
+        "row": ["b", "b", "y", "y"],
+        "col": ["c", "c", "z", "z"],
+        "value": [2, 3, 1, 2],
     }
diff --git a/py-polars/tests/unit/test_queries.py b/py-polars/tests/unit/test_queries.py
index 5ddf46531840..4c50a183af49 100644
--- a/py-polars/tests/unit/test_queries.py
+++ b/py-polars/tests/unit/test_queries.py
@@ -241,8 +241,8 @@ def map_expr(name: str) -> pl.Expr:
     ).to_dict(as_series=False) == {
         "groups": [1, 2, 3, 4],
         "out": [
-            {"sum": None, "count": None},
-            {"sum": None, "count": None},
+            None,
+            None,
             {"sum": 1, "count": 1},
             {"sum": 2, "count": 1},
         ],
diff --git a/py-polars/tests/unit/test_rows.py b/py-polars/tests/unit/test_rows.py
index 1a66f6dc7fb3..91d001a58c96 100644
--- a/py-polars/tests/unit/test_rows.py
+++ b/py-polars/tests/unit/test_rows.py
@@ -1,3 +1,5 @@
+from datetime import date
+
 import pytest
 
 import polars as pl
@@ -246,3 +248,23 @@ def test_row_constructor_uint64() -> None:
         data=[[0], [int(2**63) + 1]], schema={"x": pl.UInt64}, orient="row"
     )
     assert df.rows() == [(0,), (9223372036854775809,)]
+
+
+def test_physical_row_encoding() -> None:
+    dt_str = [
+        {
+            "ts": date(2023, 7, 1),
+            "files": "AGG_202307.xlsx",
+            "period_bins": [date(2023, 7, 1), date(2024, 1, 1)],
+        },
+    ]
+
+    df = pl.from_dicts(dt_str)
+    df_groups = df.group_by("period_bins")
+    assert df_groups.all().to_dicts() == [
+        {
+            "period_bins": [date(2023, 7, 1), date(2024, 1, 1)],
+            "ts": [date(2023, 7, 1)],
+            "files": ["AGG_202307.xlsx"],
+        }
+    ]
diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py
index 8ccb4497ac0f..4c09c986eeb9 100644
--- a/py-polars/tests/unit/test_schema.py
+++ b/py-polars/tests/unit/test_schema.py
@@ -1,4 +1,5 @@
 import pickle
+from datetime import datetime
 
 import polars as pl
@@ -14,20 +15,29 @@ def test_schema() -> None:
 
 
 def test_schema_parse_nonpolars_dtypes() -> None:
-    # Currently, no parsing is being done.
-    s = pl.Schema({"foo": pl.List, "bar": int})  # type: ignore[arg-type]
+    cardinal_directions = pl.Enum(["north", "south", "east", "west"])
+
+    s = pl.Schema({"foo": pl.List, "bar": int, "baz": cardinal_directions})  # type: ignore[arg-type]
+    s["ham"] = datetime
 
     assert s["foo"] == pl.List
-    assert s["bar"] is int
-    assert s.len() == 2
-    assert s.names() == ["foo", "bar"]
-    assert s.dtypes() == [pl.List, int]
+    assert s["bar"] == pl.Int64
+    assert s["baz"] == cardinal_directions
+    assert s["ham"] == pl.Datetime("us")
+
+    assert s.len() == 4
+    assert s.names() == ["foo", "bar", "baz", "ham"]
+    assert s.dtypes() == [pl.List, pl.Int64, cardinal_directions, pl.Datetime("us")]
+
+    assert list(s.to_python().values()) == [list, int, str, datetime]
+    assert [tp.to_python() for tp in s.dtypes()] == [list, int, str, datetime]
 
 
 def test_schema_equality() -> None:
     s1 = pl.Schema({"foo": pl.Int8(), "bar": pl.Float64()})
     s2 = pl.Schema({"foo": pl.Int8(), "bar": pl.String()})
     s3 = pl.Schema({"bar": pl.Float64(), "foo": pl.Int8()})
+
     assert s1 == s1
     assert s2 == s2
     assert s3 == s3
@@ -37,14 +47,38 @@ def test_schema_picklable() -> None:
-    s = pl.Schema({"foo": pl.Int8(), "bar": pl.String()})
-
+    s = pl.Schema(
+        {
+            "foo": pl.Int8(),
+            "bar": pl.String(),
+            "ham": pl.Struct({"x": pl.List(pl.Date)}),
+        }
+    )
     pickled = pickle.dumps(s)
     s2 = pickle.loads(pickled)
-
     assert s == s2
 
 
+def test_schema_python() -> None:
+    input = {
+        "foo": pl.Int8(),
+        "bar": pl.String(),
+        "baz": pl.Categorical("lexical"),
+        "ham": pl.Object(),
+        "spam": pl.Struct({"time": pl.List(pl.Duration), "dist": pl.Float64}),
+    }
+    expected = {
+        "foo": int,
+        "bar": str,
+        "baz": str,
+        "ham": object,
+        "spam": dict,
+    }
+    for schema in (input, input.items(), list(input.items())):
+        s = pl.Schema(schema)
+        assert expected == s.to_python()
+
+
 def test_schema_in_map_elements_returns_scalar() -> None:
     schema = pl.Schema([("portfolio", pl.String()), ("irr", pl.Float64())])
@@ -62,6 +96,11 @@ def test_schema_in_map_elements_returns_scalar() -> None:
         )
         .alias("irr")
     )
-
     assert (q.collect_schema()) == schema
     assert q.collect().schema == schema
+
+
+def test_ir_cache_unique_18198() -> None:
+    lf = pl.LazyFrame({"a": [1]})
+    lf.collect_schema()
+    assert pl.concat([lf, lf]).collect().to_dict(as_series=False) == {"a": [1, 1]}
diff --git a/py-polars/tests/unit/test_selectors.py b/py-polars/tests/unit/test_selectors.py
index e292e709a970..a5d573d10379 100644
--- a/py-polars/tests/unit/test_selectors.py
+++ b/py-polars/tests/unit/test_selectors.py
@@ -9,7 +9,7 @@
 import polars.selectors as cs
 from polars._typing import SelectorType
 from polars.dependencies import _ZONEINFO_AVAILABLE
-from polars.exceptions import ColumnNotFoundError
+from polars.exceptions import ColumnNotFoundError, InvalidOperationError
 from polars.selectors import expand_selector, is_selector
 from polars.testing import assert_frame_equal
 from tests.unit.conftest import INTEGER_DTYPES, TEMPORAL_DTYPES
@@ -30,7 +30,7 @@ def assert_repr_equals(item: Any, expected: str) -> None:
     assert repr(item) == expected
 
 
-@pytest.fixture()
+@pytest.fixture
 def df() -> pl.DataFrame:
     # set up an empty dataframe with plenty of columns of various dtypes
     df = pl.DataFrame(
@@ -809,3 +809,16 @@ def test_selector_result_order(df: pl.DataFrame, selector: SelectorType) -> None
             "qqR": pl.String,
         }
     )
+
+
+def test_selector_list_of_lists_18499() -> None:
+    lf = pl.DataFrame(
+        {
+            "foo": [1, 2, 3, 1],
+            "bar": ["a", "a", "a", "a"],
+            "ham": ["b", "b", "b", "b"],
+        }
+    )
+
+    with pytest.raises(InvalidOperationError, match="invalid selector expression"):
+        lf.unique(subset=[["bar", "ham"]])  # type: ignore[list-item]
diff --git a/py-polars/tests/unit/test_serde.py b/py-polars/tests/unit/test_serde.py
index 869e75e3bf64..c1a79025690e 100644
--- a/py-polars/tests/unit/test_serde.py
+++ b/py-polars/tests/unit/test_serde.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import io
 import pickle
 from datetime import datetime, timedelta
@@ -207,3 +208,12 @@ def test_serde_data_type_instantiated_with_attributes() -> None:
     deserialized = pickle.loads(serialized)
     assert deserialized == dtype
     assert isinstance(deserialized, pl.DataType)
+
+
+def test_serde_udf() -> None:
+    lf = pl.LazyFrame({"a": [[1, 2], [3, 4, 5]], "b": [3, 4]}).select(
+        pl.col("a").map_elements(lambda x: sum(x), return_dtype=pl.Int32)
+    )
+    result = pl.LazyFrame.deserialize(io.BytesIO(lf.serialize()))
+
+    assert_frame_equal(lf, result)
diff --git a/py-polars/tests/unit/testing/parametric/strategies/test_core.py b/py-polars/tests/unit/testing/parametric/strategies/test_core.py
index bb54255f4d10..8edf018e6f21 100644
--- a/py-polars/tests/unit/testing/parametric/strategies/test_core.py
+++ b/py-polars/tests/unit/testing/parametric/strategies/test_core.py
@@ -170,13 +170,13 @@ def test_dataframes_columns(lf: pl.LazyFrame) -> None:
     assert all(v in xyz for v in df["d"].to_list())
 
 
-@pytest.mark.hypothesis()
+@pytest.mark.hypothesis
 def test_column_invalid_probability() -> None:
     with pytest.deprecated_call(), pytest.raises(InvalidArgument):
         column("col", null_probability=2.0)
 
 
-@pytest.mark.hypothesis()
+@pytest.mark.hypothesis
 def test_column_null_probability_deprecated() -> None:
     with pytest.deprecated_call():
         col = column("col", allow_null=False, null_probability=0.5)
diff --git a/py-polars/tests/unit/testing/parametric/strategies/test_legacy.py b/py-polars/tests/unit/testing/parametric/strategies/test_legacy.py
index eab2574dc2d6..d9cd5bc5bbb7 100644
--- a/py-polars/tests/unit/testing/parametric/strategies/test_legacy.py
+++ b/py-polars/tests/unit/testing/parametric/strategies/test_legacy.py
@@ -5,14 +5,14 @@
 from polars.testing.parametric.strategies.core import _COL_LIMIT
 
 
-@pytest.mark.hypothesis()
+@pytest.mark.hypothesis
 def test_columns_deprecated() -> None:
     with pytest.deprecated_call(), pytest.warns(NonInteractiveExampleWarning):
         result = columns()
     assert 0 <= len(result) <= _COL_LIMIT
 
 
-@pytest.mark.hypothesis()
+@pytest.mark.hypothesis
 def test_create_list_strategy_deprecated() -> None:
     with pytest.deprecated_call(), pytest.warns(NonInteractiveExampleWarning):
         result = create_list_strategy(size=5)
diff --git a/py-polars/tests/unit/testing/test_assert_frame_equal.py b/py-polars/tests/unit/testing/test_assert_frame_equal.py
index 4cb8b3f5106f..f13898bb26c5 100644
--- a/py-polars/tests/unit/testing/test_assert_frame_equal.py
+++ b/py-polars/tests/unit/testing/test_assert_frame_equal.py
@@ -278,13 +278,17 @@ def test_assert_frame_equal_pass() -> None:
     assert_frame_equal(df1, df2)
 
 
-def test_assert_frame_equal_types() -> None:
+@pytest.mark.parametrize(
+    "assert_function",
+    [assert_frame_equal, assert_frame_not_equal],
+)
+def test_assert_frame_equal_types(assert_function: Any) -> None:
     df1 = pl.DataFrame({"a": [1, 2]})
     srs1 = pl.Series(values=[1, 2], name="a")
     with pytest.raises(
         AssertionError, match=r"inputs are different \(unexpected input types\)"
     ):
-        assert_frame_equal(df1, srs1)  # type: ignore[arg-type]
+        assert_function(df1, srs1)
 
 
 def test_assert_frame_equal_length_mismatch() -> None:
@@ -295,6 +299,7 @@ def test_assert_frame_equal_length_mismatch() -> None:
         match=r"DataFrames are different \(number of rows does not match\)",
     ):
         assert_frame_equal(df1, df2)
+    assert_frame_not_equal(df1, df2)
 
 
 def test_assert_frame_equal_column_mismatch() -> None:
@@ -304,6 +309,7 @@ def test_assert_frame_equal_column_mismatch() -> None:
         AssertionError, match="columns \\['a'\\] in left DataFrame, but not in right"
     ):
         assert_frame_equal(df1, df2)
+    assert_frame_not_equal(df1, df2)
 
 
 def test_assert_frame_equal_column_mismatch2() -> None:
@@ -314,6 +320,7 @@ def test_assert_frame_equal_column_mismatch2() -> None:
         match="columns \\['b', 'c'\\] in right LazyFrame, but not in left",
     ):
         assert_frame_equal(df1, df2)
+    assert_frame_not_equal(df1, df2)
 
 
 def test_assert_frame_equal_column_mismatch_order() -> None:
@@ -323,6 +330,7 @@ def test_assert_frame_equal_column_mismatch_order() -> None:
         assert_frame_equal(df1, df2)
 
     assert_frame_equal(df1, df2, check_column_order=False)
+    assert_frame_not_equal(df1, df2)
 
 
 def test_assert_frame_equal_check_row_order() -> None:
@@ -331,25 +339,33 @@ def test_assert_frame_equal_check_row_order() -> None:
     with pytest.raises(AssertionError, match="value mismatch for column 'a'"):
         assert_frame_equal(df1, df2)
+
     assert_frame_equal(df1, df2, check_row_order=False)
+    assert_frame_not_equal(df1, df2)
 
 
 def test_assert_frame_equal_check_row_col_order() -> None:
     df1 = pl.DataFrame({"a": [1, 2], "b": [4, 3]})
-    df3 = pl.DataFrame({"b": [3, 4], "a": [2, 1]})
+    df2 = pl.DataFrame({"b": [3, 4], "a": [2, 1]})
 
     with pytest.raises(AssertionError, match="columns are not in the same order"):
-        assert_frame_equal(df1, df3, check_row_order=False)
-    assert_frame_equal(df1, df3, check_row_order=False, check_column_order=False)
+        assert_frame_equal(df1, df2, check_row_order=False)
+
+    assert_frame_equal(df1, df2, check_row_order=False, check_column_order=False)
+    assert_frame_not_equal(df1, df2)
 
 
-def test_assert_frame_equal_check_row_order_unsortable() -> None:
+@pytest.mark.parametrize(
+    "assert_function",
+    [assert_frame_equal, assert_frame_not_equal],
+)
+def test_assert_frame_equal_check_row_order_unsortable(assert_function: Any) -> None:
     df1 = pl.DataFrame({"a": [object(), object()], "b": [3, 4]})
     df2 = pl.DataFrame({"a": [object(), object()], "b": [4, 3]})
 
     with pytest.raises(
         TypeError, match="cannot set `check_row_order=False`.*unsortable columns"
     ):
-        assert_frame_equal(df1, df2, check_row_order=False)
+        assert_function(df1, df2, check_row_order=False)
 
 
 def test_assert_frame_equal_dtypes_mismatch() -> None:
@@ -360,11 +376,17 @@ def test_assert_frame_equal_dtypes_mismatch() -> None:
     with pytest.raises(AssertionError, match="dtypes do not match"):
         assert_frame_equal(df1, df2, check_column_order=False)
+        assert_frame_not_equal(df1, df2, check_column_order=False)
+
+    assert_frame_not_equal(df1, df2)
+
 
 def test_assert_frame_not_equal() -> None:
     df = pl.DataFrame({"a": [1, 2]})
-    with pytest.raises(AssertionError, match="frames are equal"):
+    with pytest.raises(AssertionError, match="DataFrames are equal"):
         assert_frame_not_equal(df, df)
+    lf = df.lazy()
+    with pytest.raises(AssertionError, match="LazyFrames are equal"):
+        assert_frame_not_equal(lf, lf)
 
 
 def test_assert_frame_equal_check_dtype_deprecated() -> None:
@@ -437,6 +459,6 @@ def test_frame_schema_fail():
         "AssertionError: DataFrames are different (value mismatch for column 'a')"
         in stdout
     )
-    assert "AssertionError: frames are equal" in stdout
+    assert "AssertionError: DataFrames are equal" in stdout
     assert "AssertionError: inputs are different (unexpected input types)" in stdout
     assert "AssertionError: DataFrames are different (dtypes do not match)" in stdout
diff --git a/py-polars/tests/unit/testing/test_assert_series_equal.py b/py-polars/tests/unit/testing/test_assert_series_equal.py
index 92ebe13a0104..c523fe193a30 100644
--- a/py-polars/tests/unit/testing/test_assert_series_equal.py
+++ b/py-polars/tests/unit/testing/test_assert_series_equal.py
@@ -35,10 +35,11 @@ def test_assert_series_equal_parametric_array(data: st.DataObject) -> None:
 def test_compare_series_value_mismatch() -> None:
     srs1 = pl.Series([1, 2, 3])
     srs2 = pl.Series([2, 3, 4])
-
     assert_series_not_equal(srs1, srs2)
+
     with pytest.raises(
-        AssertionError, match=r"Series are different \(exact value mismatch\)"
+        AssertionError,
+        match=r"Series are different \(exact value mismatch\)",
     ):
         assert_series_equal(srs1, srs2)
@@ -46,25 +47,33 @@ def test_compare_series_empty_equal() -> None:
     srs1 = pl.Series([])
     srs2 = pl.Series(())
-
     assert_series_equal(srs1, srs2)
-    with pytest.raises(AssertionError):
+
+    with pytest.raises(
+        AssertionError,
+        match=r"Series are equal \(but are expected not to be\)",
+    ):
         assert_series_not_equal(srs1, srs2)
 
 
 def test_assert_series_equal_check_order() -> None:
     srs1 = pl.Series([1, 2, 3, None])
     srs2 = pl.Series([2, None, 3, 1])
-
     assert_series_equal(srs1, srs2, check_order=False)
-    with pytest.raises(AssertionError):
+
+    with pytest.raises(
+        AssertionError,
+        match=r"Series are equal \(but are expected not to be\)",
+    ):
         assert_series_not_equal(srs1, srs2, check_order=False)
 
 
 def test_assert_series_equal_check_order_unsortable_type() -> None:
     s = pl.Series([object(), object()])
-
-    with pytest.raises(TypeError):
+    with pytest.raises(
+        TypeError,
+        match="cannot set `check_order=False` on Series with unsortable data type",
+    ):
         assert_series_equal(s, s, check_order=False)
@@ -123,32 +132,45 @@ def test_compare_series_value_mismatch_string() -> None:
     assert_series_not_equal(srs1, srs2)
     with pytest.raises(
-        AssertionError, match=r"Series are different \(exact value mismatch\)"
+        AssertionError,
+        match=r"Series are different \(exact value mismatch\)",
     ):
         assert_series_equal(srs1, srs2)
 
 
-def test_compare_series_type_mismatch() -> None:
+def test_compare_series_dtype_mismatch() -> None:
     srs1 = pl.Series([1, 2, 3])
-    srs2 = pl.DataFrame({"col1": [2, 3, 4]})
+    srs2 = pl.Series([1.0, 2.0, 3.0])
 
+    assert_series_not_equal(srs1, srs2)
     with pytest.raises(
-        AssertionError, match=r"inputs are different \(unexpected input types\)"
+        AssertionError,
+        match=r"Series are different \(dtype mismatch\)",
     ):
-        assert_series_equal(srs1, srs2)  # type: ignore[arg-type]
+        assert_series_equal(srs1, srs2)
+
+
+@pytest.mark.parametrize(
+    "assert_function", [assert_series_equal, assert_series_not_equal]
+)
+def test_compare_series_input_type_mismatch(assert_function: Any) -> None:
+    srs1 = pl.Series([1, 2, 3])
+    srs2 = pl.DataFrame({"col1": [2, 3, 4]})
 
-    srs3 = pl.Series([1.0, 2.0, 3.0])
-    assert_series_not_equal(srs1, srs3)
     with pytest.raises(
-        AssertionError, match=r"Series are different \(dtype mismatch\)"
+        AssertionError,
+        match=r"inputs are different \(unexpected input types\)",
    ):
-        assert_series_equal(srs1, srs3)
+        assert_function(srs1, srs2)
 
 
 def test_compare_series_name_mismatch() -> None:
     srs1 = pl.Series(values=[1, 2, 3], name="srs1")
     srs2 = pl.Series(values=[1, 2, 3], name="srs2")
-    with pytest.raises(AssertionError, match=r"Series are different \(name mismatch\)"):
+    with pytest.raises(
+        AssertionError,
+        match=r"Series are different \(name mismatch\)",
+    ):
         assert_series_equal(srs1, srs2)
@@ -158,7 +180,8 @@ def test_compare_series_length_mismatch() -> None:
     assert_series_not_equal(srs1, srs2)
     with pytest.raises(
-        AssertionError, match=r"Series are different \(length mismatch\)"
+        AssertionError,
+        match=r"Series are different \(length mismatch\)",
     ):
         assert_series_equal(srs1, srs2)
@@ -167,7 +190,8 @@ def test_compare_series_value_exact_mismatch() -> None:
     srs1 = pl.Series([1.0, 2.0, 3.0])
     srs2 = pl.Series([1.0, 2.0 + 1e-7, 3.0])
     with pytest.raises(
-        AssertionError, match=r"Series are different \(exact value mismatch\)"
+        AssertionError,
+        match=r"Series are different \(exact value mismatch\)",
    ):
         assert_series_equal(srs1, srs2, check_exact=True)
@@ -537,7 +561,10 @@ def test_assert_series_equal_full_series() -> None:
 def test_assert_series_not_equal() -> None:
     s = pl.Series("a", [1, 2])
-    with pytest.raises(AssertionError, match="Series are equal"):
+    with pytest.raises(
+        AssertionError,
+        match=r"Series are equal \(but are expected not to be\)",
+    ):
         assert_series_not_equal(s, s)
@@ -546,7 +573,10 @@ def test_assert_series_equal_nested_list_float() -> None:
     s1 = pl.Series([[1.0, 2.0], [3.0, 4.0]], dtype=pl.List(pl.Float64))
     s2 = pl.Series([[1.0, 2.0], [3.0, 4.9]], dtype=pl.List(pl.Float64))
 
-    with pytest.raises(AssertionError):
+    with pytest.raises(
+        AssertionError,
+        match=r"Series are different \(nested value mismatch\)",
+    ):
         assert_series_equal(s1, s2)
@@ -560,7 +590,10 @@ def test_assert_series_equal_nested_struct_float() -> None:
         dtype=pl.Struct({"a": pl.Float64, "b": pl.Float64}),
     )
 
-    with pytest.raises(AssertionError):
+    with pytest.raises(
+        AssertionError,
+        match=r"Series are different \(nested value mismatch\)",
+    ):
         assert_series_equal(s1, s2)
@@ -570,7 +603,10 @@ def test_assert_series_equal_full_null_incompatible_dtypes_raises() -> None:
 
     # You could argue this should pass, but it's rare enough not to warrant the
     # additional check
-    with pytest.raises(AssertionError, match="incompatible data types"):
+    with pytest.raises(
+        AssertionError,
+        match="incompatible data types",
+    ):
         assert_series_equal(s1, s2, check_dtypes=False)
@@ -595,9 +631,16 @@ def test_assert_series_equal_uint_overflow() -> None:
     s1 = pl.Series([1, 2, 3], dtype=pl.UInt8)
     s2 = pl.Series([2, 3, 4], dtype=pl.UInt8)
 
-    with pytest.raises(AssertionError):
+    with pytest.raises(
+        AssertionError,
+        match=r"Series are different \(exact value mismatch\)",
+    ):
         assert_series_equal(s1, s2, atol=0)
-    with pytest.raises(AssertionError):
+
+    with pytest.raises(
+        AssertionError,
+        match=r"Series are different \(exact value mismatch\)",
+    ):
         assert_series_equal(s1, s2, atol=1)
 
     left = pl.Series(
@@ -616,7 +659,10 @@ def test_assert_series_equal_uint_always_checked_exactly() -> None:
     s1 = pl.Series([1, 3], dtype=pl.UInt8)
     s2 = pl.Series([2, 4], dtype=pl.Int64)
 
-    with pytest.raises(AssertionError):
+    with pytest.raises(
+        AssertionError,
+        match=r"Series are different \(exact value mismatch\)",
+    ):
         assert_series_equal(s1, s2, atol=1, check_dtypes=False)
@@ -624,9 +670,15 @@ def test_assert_series_equal_nested_int_always_checked_exactly() -> None:
     s1 = pl.Series([[1, 2], [3, 4]])
     s2 = pl.Series([[1, 2], [3, 5]])
 
-    with pytest.raises(AssertionError):
+    with pytest.raises(
+        AssertionError,
+        match=r"Series are different \(exact value mismatch\)",
+    ):
         assert_series_equal(s1, s2, atol=1)
-    with pytest.raises(AssertionError):
+    with pytest.raises(
+        AssertionError,
+        match=r"Series are different \(exact value mismatch\)",
+    ):
         assert_series_equal(s1, s2, check_exact=True)
@@ -635,7 +687,9 @@ def test_assert_series_equal_array_equal(check_exact: bool) -> None:
     s1 = pl.Series([[1.0, 2.0], [3.0, 4.0]], dtype=pl.Array(pl.Float64, 2))
     s2 = pl.Series([[1.0, 2.0], [3.0, 4.2]], dtype=pl.Array(pl.Float64, 2))
 
-    with pytest.raises(AssertionError):
+    with pytest.raises(
+        AssertionError, match=r"Series are different \(nested value mismatch\)"
+    ):
         assert_series_equal(s1, s2, check_exact=check_exact)
diff --git a/py-polars/tests/unit/utils/test_utils.py b/py-polars/tests/unit/utils/test_utils.py
index 603d0a2e959a..f6eca92215da 100644
--- a/py-polars/tests/unit/utils/test_utils.py
+++ b/py-polars/tests/unit/utils/test_utils.py
@@ -186,7 +186,7 @@ def test_parse_version(v1: Any, v2: Any) -> None:
     assert parse_version(v2) < parse_version(v1)
 
 
-@pytest.mark.slow()
+@pytest.mark.slow
 def test_in_notebook() -> None:
     # private function, but easier to test this separately and mock it in the callers
     assert not _in_notebook()
diff --git a/rust-toolchain.toml b/rust-toolchain.toml
index a6e580b10775..ef33a7d39711 100644
--- a/rust-toolchain.toml
+++ b/rust-toolchain.toml
@@ -1,2 +1,2 @@
 [toolchain]
-channel = "nightly-2024-07-26"
+channel = "nightly-2024-08-26"