From 9a4d4410737673e6ac92034312d821c1c5d2c10a Mon Sep 17 00:00:00 2001
From: Krithic Kumar
Date: Fri, 17 May 2024 23:54:23 +0530
Subject: [PATCH 1/7] feat: add support for pyarrow non-nested arrays

---
 bindings/python/Cargo.toml                    |  5 ++-
 bindings/python/src/tokenizer.rs              | 17 +++++++--
 .../python/tests/bindings/test_tokenizer.py   | 36 +++++++++++++++++++
 3 files changed, 55 insertions(+), 3 deletions(-)

diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml
index 3b1b1bbf1..6cb667041 100644
--- a/bindings/python/Cargo.toml
+++ b/bindings/python/Cargo.toml
@@ -10,7 +10,7 @@ crate-type = ["cdylib"]
 
 [dependencies]
 rayon = "1.10"
-serde = { version = "1.0", features = [ "rc", "derive" ]}
+serde = { version = "1.0", features = ["rc", "derive"] }
 serde_json = "1.0"
 libc = "0.2"
 env_logger = "0.11"
@@ -19,6 +19,9 @@ numpy = "0.21"
 ndarray = "0.15"
 onig = { version = "6.4", default-features = false }
 itertools = "0.12"
+arrow = { git = "https://github.com/apache/arrow-rs", branch = "master", features = [
+    "pyarrow",
+] }
 
 [dependencies.tokenizers]
 path = "../../tokenizers"
diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 1c6bc9cc1..ba472105e 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -263,11 +263,24 @@ impl PyAddedToken {
 struct TextInputSequence<'s>(tk::InputSequence<'s>);
 impl<'s> FromPyObject<'s> for TextInputSequence<'s> {
     fn extract(ob: &'s PyAny) -> PyResult<Self> {
-        let err = exceptions::PyTypeError::new_err("TextInputSequence must be str");
         if let Ok(s) = ob.downcast::<PyString>() {
             Ok(Self(s.to_string_lossy().into()))
         } else {
-            Err(err)
+            let str_scalar_class = PyModule::import_bound(ob.py(), "pyarrow")
+                .map(Bound::into_gil_ref)?
+                .getattr("StringScalar")?;
+            if let Ok(true) = ob.is_instance(str_scalar_class) {
+                let buf = ob.call_method0("as_buffer")?;
+                let addr = buf.getattr("address")?.extract::<usize>()?;
+                let size = buf.getattr("size")?.extract::<usize>()?;
+
+                let parts = unsafe { std::slice::from_raw_parts(addr as *const u8, size) };
+                let x = String::from_utf8_lossy(&parts[..]);
+                Ok(Self(x.into()))
+            } else {
+                let err = exceptions::PyTypeError::new_err("TextInputSequence must be str");
+                Err(err)
+            }
         }
     }
 }
diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py
index 3ac50e00c..1634cb2b8 100644
--- a/bindings/python/tests/bindings/test_tokenizer.py
+++ b/bindings/python/tests/bindings/test_tokenizer.py
@@ -1,6 +1,7 @@
 import pickle
 
 import numpy as np
+import pyarrow as pa
 import pytest
 
 from tokenizers import AddedToken, Encoding, Tokenizer
@@ -200,6 +201,11 @@ def test_pair(input, is_pretokenized=False):
         test_pair(np.array([("My name is John", "pair"), ("My name is Georges", "pair")]))
         test_pair(np.array([["My name is John", "pair"], ["My name is Georges", "pair"]]))
 
+        # Pyarrow
+        test_single(pa.array(["My name is John", "My name is Georges"]))
+        test_pair(pa.array([("My name is John", "pair"), ("My name is Georges", "pair")]))
+        test_pair(pa.array([["My name is John", "pair"], ["My name is Georges", "pair"]]))
+
         # PreTokenized inputs
 
         # Lists
@@ -266,6 +272,36 @@ def test_pair(input, is_pretokenized=False):
             True,
         )
 
+        # Pyarrow
+        test_single(
+            pa.array([["My", "name", "is", "John"], ["My", "name", "is", "Georges"]]),
+            True,
+        )
+        test_single(
+            pa.array((("My", "name", "is", "John"), ("My", "name", "is", "Georges"))),
+            True,
+        )
+        test_pair(
+            pa.array(
+                [
+                    [["My", "name", "is", "John"], ["pair"]],
+                    [["My", "name", "is", "Georges"], ["pair"]],
+                ],
+                dtype=object,
+            ),
+            True,
+        )
+        test_pair(
+            pa.array(
+                (
+                    (("My", "name", "is", "John"), ("pair",)),
+                    (("My", "name", "is", "Georges"), ("pair",)),
+                ),
+                dtype=object,
+            ),
+            True,
+        )
+
         # Mal formed
         with pytest.raises(TypeError, match="TextInputSequence must be str"):
             tokenizer.encode([["my", "name"]])

From 6bbf587ff6f20d9b15bba1373cc44add5c4d2b9b Mon Sep 17 00:00:00 2001
From: Krithic Kumar
Date: Sat, 18 May 2024 09:57:39 +0530
Subject: [PATCH 2/7] feat: add support for pyarrow arrays with pretokenized
 inputs

---
 bindings/python/src/tokenizer.rs              | 39 +++++++++++++++++++
 .../python/tests/bindings/test_tokenizer.py   |  2 -
 2 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index ba472105e..b61d55259 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -1,3 +1,4 @@
+use std::borrow::Cow;
 use std::collections::{hash_map::DefaultHasher, HashMap};
 use std::hash::{Hash, Hasher};
 
@@ -260,6 +261,7 @@ impl PyAddedToken {
     }
 }
 
+#[derive(Debug)]
 struct TextInputSequence<'s>(tk::InputSequence<'s>);
 impl<'s> FromPyObject<'s> for TextInputSequence<'s> {
     fn extract(ob: &'s PyAny) -> PyResult<Self> {
@@ -357,6 +359,40 @@ impl From<PyArrayUnicode> for tk::InputSequence<'_> {
     }
 }
 
+struct PyArrowUnicode<'s>(Vec<Cow<'s, str>>);
+impl<'s> FromPyObject<'s> for PyArrowUnicode<'s> {
+    fn extract(ob: &'_ PyAny) -> PyResult<Self> {
+        let list = ob.extract::<Vec<&PyAny>>()?;
+        let mut pa_str_list = Vec::with_capacity(list.len());
+        for item in list.iter() {
+            let str_scalar_class = PyModule::import_bound(ob.py(), "pyarrow")
+                .map(Bound::into_gil_ref)?
+                .getattr("StringScalar")?;
+            if let Ok(true) = item.is_instance(str_scalar_class) {
+                let buf = item.call_method0("as_buffer")?;
+                let addr = buf.getattr("address")?.extract::<usize>()?;
+                let size = buf.getattr("size")?.extract::<usize>()?;
+
+                let parts = unsafe { std::slice::from_raw_parts(addr as *const u8, size) };
+                let cow_str = String::from_utf8_lossy(&parts[..]);
+                pa_str_list.push(cow_str);
+            } else {
+                let err = exceptions::PyTypeError::new_err(
+                    "TextInputSequence is not pyarrow.StringScalar",
+                );
+                return Err(err);
+            }
+        }
+
+        Ok(Self(pa_str_list))
+    }
+}
+impl<'s> From<PyArrowUnicode<'s>> for tk::InputSequence<'s> {
+    fn from(s: PyArrowUnicode<'s>) -> Self {
+        s.0.into()
+    }
+}
+
 struct PyArrayStr(Vec<String>);
 impl FromPyObject<'_> for PyArrayStr {
     fn extract(ob: &PyAny) -> PyResult<Self> {
@@ -386,6 +422,9 @@ impl<'s> FromPyObject<'s> for PreTokenizedInputSequence<'s> {
         if let Ok(seq) = ob.extract::<PyArrayUnicode>() {
             return Ok(Self(seq.into()));
         }
+        if let Ok(seq) = ob.extract::<PyArrowUnicode>() {
+            return Ok(Self(seq.into()));
+        }
         if let Ok(seq) = ob.extract::<PyArrayStr>() {
             return Ok(Self(seq.into()));
         }
diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py
index 1634cb2b8..6fb6f250d 100644
--- a/bindings/python/tests/bindings/test_tokenizer.py
+++ b/bindings/python/tests/bindings/test_tokenizer.py
@@ -287,7 +287,6 @@ def test_pair(input, is_pretokenized=False):
                     [["My", "name", "is", "John"], ["pair"]],
                     [["My", "name", "is", "Georges"], ["pair"]],
                 ],
-                dtype=object,
             ),
             True,
         )
@@ -297,7 +296,6 @@ def test_pair(input, is_pretokenized=False):
                     (("My", "name", "is", "John"), ("pair",)),
                     (("My", "name", "is", "Georges"), ("pair",)),
                 ),
-                dtype=object,
             ),
             True,
         )

From 48a33d71c65db42736852fd4c5793c3c1909365a Mon Sep 17 00:00:00 2001
From: Krithic Kumar
Date: Sat, 18 May 2024 10:19:02 +0530
Subject: [PATCH 3/7] refactor: abstract `StringScalar` extraction logic into
 a struct

---
 bindings/python/src/tokenizer.rs | 89 ++++++++++++++++----------
 1 file changed, 44 insertions(+), 45 deletions(-)

diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index b61d55259..ab7d5c7b2 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -261,28 +261,44 @@ impl PyAddedToken {
     }
 }
 
-#[derive(Debug)]
+struct PyArrowScalarStringInput<'s>(Cow<'s, str>);
+impl<'s> FromPyObject<'s> for PyArrowScalarStringInput<'s> {
+    fn extract(ob: &'s PyAny) -> PyResult<Self> {
+        let str_scalar_class = PyModule::import_bound(ob.py(), "pyarrow")
+            .map(Bound::into_gil_ref)?
+            .getattr("StringScalar")?;
+        if ob.is_instance(str_scalar_class)? {
+            let buf = ob.call_method0("as_buffer")?;
+            let addr = buf.getattr("address")?.extract::<usize>()?;
+            let size = buf.getattr("size")?.extract::<usize>()?;
+
+            // SAFETY: address is valid because it's from the StringScalar buffer
+            let parts = unsafe { std::slice::from_raw_parts(addr as *const u8, size) };
+            let x = String::from_utf8_lossy(&parts[..]);
+            Ok(Self(x.into()))
+        } else {
+            let err =
+                exceptions::PyTypeError::new_err("TextInputSequence must be pyarrow.StringScalar");
+            Err(err)
+        }
+    }
+}
+impl<'s> From<PyArrowScalarStringInput<'s>> for tk::InputSequence<'s> {
+    fn from(s: PyArrowScalarStringInput<'s>) -> Self {
+        s.0.into()
+    }
+}
+
 struct TextInputSequence<'s>(tk::InputSequence<'s>);
 impl<'s> FromPyObject<'s> for TextInputSequence<'s> {
     fn extract(ob: &'s PyAny) -> PyResult<Self> {
         if let Ok(s) = ob.downcast::<PyString>() {
             Ok(Self(s.to_string_lossy().into()))
+        } else if let Ok(s) = ob.extract::<PyArrowScalarStringInput>() {
+            Ok(Self(s.0.into()))
         } else {
-            let str_scalar_class = PyModule::import_bound(ob.py(), "pyarrow")
-                .map(Bound::into_gil_ref)?
-                .getattr("StringScalar")?;
-            if let Ok(true) = ob.is_instance(str_scalar_class) {
-                let buf = ob.call_method0("as_buffer")?;
-                let addr = buf.getattr("address")?.extract::<usize>()?;
-                let size = buf.getattr("size")?.extract::<usize>()?;
-
-                let parts = unsafe { std::slice::from_raw_parts(addr as *const u8, size) };
-                let x = String::from_utf8_lossy(&parts[..]);
-                Ok(Self(x.into()))
-            } else {
-                let err = exceptions::PyTypeError::new_err("TextInputSequence must be str");
-                Err(err)
-            }
+            let err = exceptions::PyTypeError::new_err("TextInputSequence must be str");
+            Err(err)
         }
     }
 }
@@ -359,36 +375,19 @@ impl From<PyArrayUnicode> for tk::InputSequence<'_> {
     }
 }
 
-struct PyArrowUnicode<'s>(Vec<Cow<'s, str>>);
-impl<'s> FromPyObject<'s> for PyArrowUnicode<'s> {
-    fn extract(ob: &'_ PyAny) -> PyResult<Self> {
-        let list = ob.extract::<Vec<&PyAny>>()?;
-        let mut pa_str_list = Vec::with_capacity(list.len());
-        for item in list.iter() {
-            let str_scalar_class = PyModule::import_bound(ob.py(), "pyarrow")
-                .map(Bound::into_gil_ref)?
-                .getattr("StringScalar")?;
-            if let Ok(true) = item.is_instance(str_scalar_class) {
-                let buf = item.call_method0("as_buffer")?;
-                let addr = buf.getattr("address")?.extract::<usize>()?;
-                let size = buf.getattr("size")?.extract::<usize>()?;
-
-                let parts = unsafe { std::slice::from_raw_parts(addr as *const u8, size) };
-                let cow_str = String::from_utf8_lossy(&parts[..]);
-                pa_str_list.push(cow_str);
-            } else {
-                let err = exceptions::PyTypeError::new_err(
-                    "TextInputSequence is not pyarrow.StringScalar",
-                );
-                return Err(err);
-            }
-        }
-
-        Ok(Self(pa_str_list))
+struct PyArrowArray<'s>(Vec<Cow<'s, str>>);
+impl<'s> FromPyObject<'s> for PyArrowArray<'s> {
+    fn extract(ob: &'s PyAny) -> PyResult<Self> {
+        let array = ob.extract::<Vec<&PyAny>>()?;
+        let str_array: Vec<Cow<'s, str>> = array
+            .iter()
+            .map(|item| item.extract::<PyArrowScalarStringInput>().map(|res| res.0))
+            .collect::<PyResult<Vec<Cow<'s, str>>>>()?;
+        Ok(Self(str_array))
     }
 }
-impl<'s> From<PyArrowUnicode<'s>> for tk::InputSequence<'s> {
-    fn from(s: PyArrowUnicode<'s>) -> Self {
+impl<'s> From<PyArrowArray<'s>> for tk::InputSequence<'s> {
+    fn from(s: PyArrowArray<'s>) -> Self {
         s.0.into()
     }
 }
@@ -422,7 +421,7 @@ impl<'s> FromPyObject<'s> for PreTokenizedInputSequence<'s> {
         if let Ok(seq) = ob.extract::<PyArrayUnicode>() {
             return Ok(Self(seq.into()));
         }
-        if let Ok(seq) = ob.extract::<PyArrowUnicode>() {
+        if let Ok(seq) = ob.extract::<PyArrowArray>() {
            return Ok(Self(seq.into()));
         }

From 5e8148a2625314318449421baad02c236256bba5 Mon Sep 17 00:00:00 2001
From: Krithic Kumar
Date: Sat, 18 May 2024 11:12:27 +0530
Subject: [PATCH 4/7] feat: add support for `LargeStringScalar`

---
 bindings/python/src/tokenizer.rs | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index ab7d5c7b2..92420b91a 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -264,10 +264,10 @@ impl PyAddedToken {
 struct PyArrowScalarStringInput<'s>(Cow<'s, str>);
 impl<'s> FromPyObject<'s> for PyArrowScalarStringInput<'s> {
     fn extract(ob: &'s PyAny) -> PyResult<Self> {
-        let str_scalar_class = PyModule::import_bound(ob.py(), "pyarrow")
-            .map(Bound::into_gil_ref)?
-            .getattr("StringScalar")?;
-        if ob.is_instance(str_scalar_class)? {
+        let pyarrow = PyModule::import_bound(ob.py(), "pyarrow").map(Bound::into_gil_ref)?;
+        let str_scalar_class = pyarrow.getattr("StringScalar")?;
+        let large_str_scalar_class = pyarrow.getattr("LargeStringScalar")?;
+        if ob.is_exact_instance(str_scalar_class) || ob.is_exact_instance(large_str_scalar_class) {
             let buf = ob.call_method0("as_buffer")?;
             let addr = buf.getattr("address")?.extract::<usize>()?;
             let size = buf.getattr("size")?.extract::<usize>()?;

From 4e0114b14359e57c76e27bff42ca33ea8ea393b7 Mon Sep 17 00:00:00 2001
From: Krithic Kumar
Date: Sat, 18 May 2024 11:35:11 +0530
Subject: [PATCH 5/7] chore: remove redundant `.into()` call and make lifetime
 of `buf_slice` clear

---
 bindings/python/src/tokenizer.rs | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 92420b91a..7a5a4d6a5 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -1,6 +1,7 @@
 use std::borrow::Cow;
 use std::collections::{hash_map::DefaultHasher, HashMap};
 use std::hash::{Hash, Hasher};
+use std::slice;
 
 use numpy::{npyffi, PyArray1};
 use pyo3::class::basic::CompareOp;
@@ -273,9 +274,9 @@ impl<'s> FromPyObject<'s> for PyArrowScalarStringInput<'s> {
             let size = buf.getattr("size")?.extract::<usize>()?;
 
             // SAFETY: address is valid because it's from the StringScalar buffer
-            let parts = unsafe { std::slice::from_raw_parts(addr as *const u8, size) };
-            let x = String::from_utf8_lossy(&parts[..]);
-            Ok(Self(x.into()))
+            let buf_slice = unsafe { slice::from_raw_parts::<'s>(addr as *const u8, size) };
+            let x = String::from_utf8_lossy(&buf_slice[..]);
+            Ok(Self(x))
         } else {
             let err =
                 exceptions::PyTypeError::new_err("TextInputSequence must be pyarrow.StringScalar");

From 086d0408165750f962112a296f26043283de47a4 Mon Sep 17 00:00:00 2001
From: Krithic Kumar
Date: Fri, 7 Jun 2024 18:17:15 +0530
Subject: [PATCH 6/7] chore: fix clippy error

---
 bindings/python/src/tokenizer.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 7a5a4d6a5..c79cf6970 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -275,7 +275,7 @@ impl<'s> FromPyObject<'s> for PyArrowScalarStringInput<'s> {
 
             // SAFETY: address is valid because it's from the StringScalar buffer
             let buf_slice = unsafe { slice::from_raw_parts::<'s>(addr as *const u8, size) };
-            let x = String::from_utf8_lossy(&buf_slice[..]);
+            let x = String::from_utf8_lossy(buf_slice);
             Ok(Self(x))
         } else {
             let err =

From b94172875ffea286727d708c779bd89d446a3d7c Mon Sep 17 00:00:00 2001
From: Krithic Kumar
Date: Fri, 7 Jun 2024 19:28:09 +0530
Subject: [PATCH 7/7] refactor: make pyarrow an optional dep

---
 bindings/python/Cargo.toml       |  3 ++-
 bindings/python/pyproject.toml   |  2 +-
 bindings/python/src/tokenizer.rs | 18 +++++++++++++++---
 3 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml
index 6cb667041..2e5b3c6ba 100644
--- a/bindings/python/Cargo.toml
+++ b/bindings/python/Cargo.toml
@@ -21,7 +21,7 @@ onig = { version = "6.4", default-features = false }
 itertools = "0.12"
 arrow = { git = "https://github.com/apache/arrow-rs", branch = "master", features = [
     "pyarrow",
-] }
+], optional = true }
 
 [dependencies.tokenizers]
 path = "../../tokenizers"
@@ -32,3 +32,4 @@ pyo3 = { version = "0.21", features = ["auto-initialize"] }
 
 [features]
 defaut = ["pyo3/extension-module"]
+pyarrow = ["arrow"]
diff --git a/bindings/python/pyproject.toml b/bindings/python/pyproject.toml
index 5cdf090fa..106c2806e 100644
--- a/bindings/python/pyproject.toml
+++ b/bindings/python/pyproject.toml
@@ -47,7 +47,7 @@ build-backend = "maturin"
 python-source = "py_src"
 module-name = "tokenizers.tokenizers"
 bindings = 'pyo3'
-features = ["pyo3/extension-module"]
+features = ["pyo3/extension-module", "pyarrow"]
 
 [tool.black]
 line-length = 119
diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index c79cf6970..290a7a870 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -1,6 +1,9 @@
-use std::borrow::Cow;
 use std::collections::{hash_map::DefaultHasher, HashMap};
 use std::hash::{Hash, Hasher};
+
+#[cfg(feature = "pyarrow")]
+use std::borrow::Cow;
+#[cfg(feature = "pyarrow")]
 use std::slice;
 
 use numpy::{npyffi, PyArray1};
@@ -262,7 +265,9 @@ impl PyAddedToken {
     }
 }
 
+#[cfg(feature = "pyarrow")]
 struct PyArrowScalarStringInput<'s>(Cow<'s, str>);
+#[cfg(feature = "pyarrow")]
 impl<'s> FromPyObject<'s> for PyArrowScalarStringInput<'s> {
     fn extract(ob: &'s PyAny) -> PyResult<Self> {
         let pyarrow = PyModule::import_bound(ob.py(), "pyarrow").map(Bound::into_gil_ref)?;
@@ -284,6 +289,7 @@ impl<'s> FromPyObject<'s> for PyArrowScalarStringInput<'s> {
         }
     }
 }
+#[cfg(feature = "pyarrow")]
 impl<'s> From<PyArrowScalarStringInput<'s>> for tk::InputSequence<'s> {
     fn from(s: PyArrowScalarStringInput<'s>) -> Self {
         s.0.into()
@@ -295,9 +301,11 @@ impl<'s> FromPyObject<'s> for TextInputSequence<'s> {
     fn extract(ob: &'s PyAny) -> PyResult<Self> {
         if let Ok(s) = ob.downcast::<PyString>() {
             Ok(Self(s.to_string_lossy().into()))
-        } else if let Ok(s) = ob.extract::<PyArrowScalarStringInput>() {
-            Ok(Self(s.0.into()))
         } else {
+            #[cfg(feature = "pyarrow")]
+            if let Ok(s) = ob.extract::<PyArrowScalarStringInput>() {
+                return Ok(Self(s.0.into()));
+            }
             let err = exceptions::PyTypeError::new_err("TextInputSequence must be str");
             Err(err)
         }
@@ -376,7 +384,9 @@ impl From<PyArrayUnicode> for tk::InputSequence<'_> {
     }
 }
 
+#[cfg(feature = "pyarrow")]
 struct PyArrowArray<'s>(Vec<Cow<'s, str>>);
+#[cfg(feature = "pyarrow")]
 impl<'s> FromPyObject<'s> for PyArrowArray<'s> {
     fn extract(ob: &'s PyAny) -> PyResult<Self> {
         let array = ob.extract::<Vec<&PyAny>>()?;
@@ -387,6 +397,7 @@ impl<'s> FromPyObject<'s> for PyArrowArray<'s> {
         Ok(Self(str_array))
     }
 }
+#[cfg(feature = "pyarrow")]
 impl<'s> From<PyArrowArray<'s>> for tk::InputSequence<'s> {
     fn from(s: PyArrowArray<'s>) -> Self {
         s.0.into()
@@ -422,6 +433,7 @@ impl<'s> FromPyObject<'s> for PreTokenizedInputSequence<'s> {
         if let Ok(seq) = ob.extract::<PyArrayUnicode>() {
             return Ok(Self(seq.into()));
         }
+        #[cfg(feature = "pyarrow")]
         if let Ok(seq) = ob.extract::<PyArrowArray>() {
             return Ok(Self(seq.into()));
         }
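
Usage sketch of what this series enables, assuming the bindings were built
with the new "pyarrow" feature (enabled by default through the
[tool.maturin] features list patched above). The example mirrors the new
tests; the tokenizer.json path is hypothetical and any serialized tokenizer
works:

    import pyarrow as pa
    from tokenizers import Tokenizer

    # Hypothetical file name; load any serialized tokenizer here.
    tokenizer = Tokenizer.from_file("tokenizer.json")

    # Elements of a pyarrow string array reach Rust as StringScalar (or
    # LargeStringScalar) objects; the new FromPyObject impls read their
    # bytes straight from the Arrow buffer rather than going through a
    # Python str conversion first.
    batch = pa.array(["My name is John", "My name is Georges"])
    encodings = tokenizer.encode_batch(batch)
    print(encodings[0].tokens)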