From cdc0f3b1490abdccc8b3c6d9c8706d37b1f5a59e Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Sat, 10 Aug 2024 21:46:59 +0100 Subject: [PATCH 1/2] serialize set and frozenset Co-authored-by: Lily Foote --- CHANGELOG.md | 1 + src/de.rs | 94 ++++++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 88 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7544af4..4283d66 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ ### Fixed - Fix overflow error attempting to depythonize `u64` values greater than `i64::MAX` to types like `serde_json::Value` +- Fix deserializing `set` and `frozenset` into Rust sequences ## 0.21.1 - 2024-04-02 diff --git a/src/de.rs b/src/de.rs index b1369c6..51fd155 100644 --- a/src/de.rs +++ b/src/de.rs @@ -32,15 +32,28 @@ impl<'a, 'py> Depythonizer<'a, 'py> { Depythonizer { input } } - fn sequence_access(&self, expected_len: Option) -> Result> { - let seq = self.input.downcast::()?; + fn sequence_access(&self, expected_len: Option) -> Result> { + let seq = match self.input.downcast::() { + Ok(seq) => seq, + Err(e) => { + return if let Ok(set) = self.input.downcast::() { + Ok(SequenceAccess::Set(PySetAsSequence::from_set(&set))) + } else if let Ok(frozenset) = self.input.downcast::() { + Ok(SequenceAccess::Set(PySetAsSequence::from_frozenset( + &frozenset, + ))) + } else { + Err(e.into()) + } + } + }; let len = self.input.len()?; match expected_len { Some(expected) if expected != len => { Err(PythonizeError::incorrect_sequence_length(expected, len)) } - _ => Ok(PySequenceAccess::new(seq, len)), + _ => Ok(SequenceAccess::Sequence(PySequenceAccess::new(seq, len))), } } @@ -238,14 +251,20 @@ impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> { where V: de::Visitor<'de>, { - visitor.visit_seq(self.sequence_access(None)?) + match self.sequence_access(None)? { + SequenceAccess::Sequence(seq) => visitor.visit_seq(seq), + SequenceAccess::Set(set) => visitor.visit_seq(set), + } } fn deserialize_tuple(self, len: usize, visitor: V) -> Result where V: de::Visitor<'de>, { - visitor.visit_seq(self.sequence_access(Some(len))?) + match self.sequence_access(Some(len))? { + SequenceAccess::Sequence(seq) => visitor.visit_seq(seq), + SequenceAccess::Set(set) => visitor.visit_seq(set), + } } fn deserialize_tuple_struct( @@ -257,7 +276,10 @@ impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> { where V: de::Visitor<'de>, { - visitor.visit_seq(self.sequence_access(Some(len))?) + match self.sequence_access(Some(len))? { + SequenceAccess::Sequence(seq) => visitor.visit_seq(seq), + SequenceAccess::Set(set) => visitor.visit_seq(set), + } } fn deserialize_map(self, visitor: V) -> Result @@ -327,6 +349,11 @@ impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> { } } +enum SequenceAccess<'a, 'py> { + Sequence(PySequenceAccess<'a, 'py>), + Set(PySetAsSequence<'py>), +} + struct PySequenceAccess<'a, 'py> { seq: &'a Bound<'py, PySequence>, index: usize, @@ -357,6 +384,40 @@ impl<'de> de::SeqAccess<'de> for PySequenceAccess<'_, '_> { } } +struct PySetAsSequence<'py> { + iter: Bound<'py, PyIterator>, +} + +impl<'py> PySetAsSequence<'py> { + fn from_set(set: &Bound<'py, PySet>) -> Self { + Self { + iter: PyIterator::from_bound_object(&set).expect("set is always iterable"), + } + } + + fn from_frozenset(set: &Bound<'py, PyFrozenSet>) -> Self { + Self { + iter: PyIterator::from_bound_object(&set).expect("frozenset is always iterable"), + } + } +} + +impl<'de> de::SeqAccess<'de> for PySetAsSequence<'_> { + type Error = PythonizeError; + + fn next_element_seed(&mut self, seed: T) -> Result> + where + T: de::DeserializeSeed<'de>, + { + match self.iter.next() { + Some(item) => seed + .deserialize(&mut Depythonizer::from_object(&item?)) + .map(Some), + None => Ok(None), + } + } +} + struct PyMappingAccess<'py> { keys: Bound<'py, PySequence>, values: Bound<'py, PySequence>, @@ -454,7 +515,10 @@ impl<'de> de::VariantAccess<'de> for PyEnumAccess<'_, '_> { where V: de::Visitor<'de>, { - visitor.visit_seq(self.de.sequence_access(Some(len))?) + match self.de.sequence_access(Some(len))? { + SequenceAccess::Sequence(seq) => visitor.visit_seq(seq), + SequenceAccess::Set(set) => visitor.visit_seq(set), + } } fn struct_variant(self, _fields: &'static [&'static str], visitor: V) -> Result @@ -606,6 +670,22 @@ mod test { test_de(code, &expected, &expected_json); } + #[test] + fn test_tuple_from_pyset() { + let expected = ("foo".to_string(), 5); + let expected_json = json!(["foo", 5]); + let code = "{'foo', 5}"; + test_de(code, &expected, &expected_json); + } + + #[test] + fn test_tuple_from_pyfrozenset() { + let expected = ("foo".to_string(), 5); + let expected_json = json!(["foo", 5]); + let code = "frozenset({'foo', 5})"; + test_de(code, &expected, &expected_json); + } + #[test] fn test_vec() { let expected = vec![3, 2, 1]; From 2666d63def4cd316d6a8f7a453bad16f34f57ad4 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Sat, 10 Aug 2024 22:12:35 +0100 Subject: [PATCH 2/2] only allow sets into homogeneous containers --- CHANGELOG.md | 2 +- src/de.rs | 91 ++++++++++++++++++++++++---------------------------- src/error.rs | 9 ++++++ 3 files changed, 52 insertions(+), 50 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4283d66..30ff327 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,7 +22,7 @@ ### Fixed - Fix overflow error attempting to depythonize `u64` values greater than `i64::MAX` to types like `serde_json::Value` -- Fix deserializing `set` and `frozenset` into Rust sequences +- Fix deserializing `set` and `frozenset` into Rust homogeneous containers ## 0.21.1 - 2024-04-02 diff --git a/src/de.rs b/src/de.rs index 51fd155..8b65825 100644 --- a/src/de.rs +++ b/src/de.rs @@ -2,7 +2,7 @@ use pyo3::{types::*, Bound}; use serde::de::{self, DeserializeOwned, IntoDeserializer}; use serde::Deserialize; -use crate::error::{PythonizeError, Result}; +use crate::error::{ErrorImpl, PythonizeError, Result}; /// Attempt to convert a Python object to an instance of `T` pub fn depythonize<'a, 'py, T>(obj: &'a Bound<'py, PyAny>) -> Result @@ -32,28 +32,28 @@ impl<'a, 'py> Depythonizer<'a, 'py> { Depythonizer { input } } - fn sequence_access(&self, expected_len: Option) -> Result> { - let seq = match self.input.downcast::() { - Ok(seq) => seq, - Err(e) => { - return if let Ok(set) = self.input.downcast::() { - Ok(SequenceAccess::Set(PySetAsSequence::from_set(&set))) - } else if let Ok(frozenset) = self.input.downcast::() { - Ok(SequenceAccess::Set(PySetAsSequence::from_frozenset( - &frozenset, - ))) - } else { - Err(e.into()) - } - } - }; + fn sequence_access(&self, expected_len: Option) -> Result> { + let seq = self.input.downcast::()?; let len = self.input.len()?; match expected_len { Some(expected) if expected != len => { Err(PythonizeError::incorrect_sequence_length(expected, len)) } - _ => Ok(SequenceAccess::Sequence(PySequenceAccess::new(seq, len))), + _ => Ok(PySequenceAccess::new(seq, len)), + } + } + + fn set_access(&self) -> Result> { + match self.input.downcast::() { + Ok(set) => Ok(PySetAsSequence::from_set(&set)), + Err(e) => { + if let Ok(f) = self.input.downcast::() { + Ok(PySetAsSequence::from_frozenset(&f)) + } else { + Err(e.into()) + } + } } } @@ -135,10 +135,9 @@ impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> { self.deserialize_bytes(visitor) } else if obj.is_instance_of::() { self.deserialize_f64(visitor) - } else if obj.is_instance_of::() - || obj.is_instance_of::() - || obj.downcast::().is_ok() - { + } else if obj.is_instance_of::() || obj.is_instance_of::() { + self.deserialize_seq(visitor) + } else if obj.downcast::().is_ok() { self.deserialize_tuple(obj.len()?, visitor) } else if obj.downcast::().is_ok() { self.deserialize_map(visitor) @@ -251,9 +250,17 @@ impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> { where V: de::Visitor<'de>, { - match self.sequence_access(None)? { - SequenceAccess::Sequence(seq) => visitor.visit_seq(seq), - SequenceAccess::Set(set) => visitor.visit_seq(set), + match self.sequence_access(None) { + Ok(seq) => visitor.visit_seq(seq), + Err(e) => { + // we allow sets to be deserialized as sequences, so try that + if matches!(*e.inner, ErrorImpl::UnexpectedType(_)) { + if let Ok(set) = self.set_access() { + return visitor.visit_seq(set); + } + } + Err(e) + } } } @@ -261,10 +268,7 @@ impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> { where V: de::Visitor<'de>, { - match self.sequence_access(Some(len))? { - SequenceAccess::Sequence(seq) => visitor.visit_seq(seq), - SequenceAccess::Set(set) => visitor.visit_seq(set), - } + visitor.visit_seq(self.sequence_access(Some(len))?) } fn deserialize_tuple_struct( @@ -276,10 +280,7 @@ impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> { where V: de::Visitor<'de>, { - match self.sequence_access(Some(len))? { - SequenceAccess::Sequence(seq) => visitor.visit_seq(seq), - SequenceAccess::Set(set) => visitor.visit_seq(set), - } + visitor.visit_seq(self.sequence_access(Some(len))?) } fn deserialize_map(self, visitor: V) -> Result @@ -349,11 +350,6 @@ impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> { } } -enum SequenceAccess<'a, 'py> { - Sequence(PySequenceAccess<'a, 'py>), - Set(PySetAsSequence<'py>), -} - struct PySequenceAccess<'a, 'py> { seq: &'a Bound<'py, PySequence>, index: usize, @@ -515,10 +511,7 @@ impl<'de> de::VariantAccess<'de> for PyEnumAccess<'_, '_> { where V: de::Visitor<'de>, { - match self.de.sequence_access(Some(len))? { - SequenceAccess::Sequence(seq) => visitor.visit_seq(seq), - SequenceAccess::Set(set) => visitor.visit_seq(set), - } + visitor.visit_seq(self.de.sequence_access(Some(len))?) } fn struct_variant(self, _fields: &'static [&'static str], visitor: V) -> Result @@ -671,18 +664,18 @@ mod test { } #[test] - fn test_tuple_from_pyset() { - let expected = ("foo".to_string(), 5); - let expected_json = json!(["foo", 5]); - let code = "{'foo', 5}"; + fn test_vec_from_pyset() { + let expected = vec!["foo".to_string()]; + let expected_json = json!(["foo"]); + let code = "{'foo'}"; test_de(code, &expected, &expected_json); } #[test] - fn test_tuple_from_pyfrozenset() { - let expected = ("foo".to_string(), 5); - let expected_json = json!(["foo", 5]); - let code = "frozenset({'foo', 5})"; + fn test_vec_from_pyfrozenset() { + let expected = vec!["foo".to_string()]; + let expected_json = json!(["foo"]); + let code = "frozenset({'foo'})"; test_de(code, &expected, &expected_json); } diff --git a/src/error.rs b/src/error.rs index 4aee7ea..9aa5a87 100644 --- a/src/error.rs +++ b/src/error.rs @@ -32,6 +32,15 @@ impl PythonizeError { } } + pub(crate) fn unexpected_type(t: T) -> Self + where + T: ToString, + { + Self { + inner: Box::new(ErrorImpl::UnexpectedType(t.to_string())), + } + } + pub(crate) fn dict_key_not_string() -> Self { Self { inner: Box::new(ErrorImpl::DictKeyNotString),