Skip to content

Commit

Permalink
feat(pydantic): allow pydantic model instances as defaults (#64) (#65)
Browse files Browse the repository at this point in the history
  • Loading branch information
faph authored Mar 18, 2024
2 parents 40687f5 + b841268 commit d7289db
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
[flake8]
# Flake 8 linting configuration, not supported in pyproject.toml

ignore = E203,W503
ignore = E203,W503,E701
max-line-length = 120
exclude =
.svn,
Expand Down
10 changes: 8 additions & 2 deletions src/py_avro_schema/_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,8 +619,9 @@ def handles_type(cls, py_type: Type) -> bool:

# Support for `X | Y` syntax available in Python 3.10+
# equivalent to `typing.Union[X, Y]`
if getattr(types, "UnionType", None):
return origin == Union or origin == types.UnionType # noqa: E721
union_type = getattr(types, "UnionType", None)
if union_type:
return origin == Union or origin == union_type
return origin == Union

def __init__(self, py_type: Type[Union[Any]], namespace: Optional[str] = None, options: Option = Option(0)):
Expand Down Expand Up @@ -862,6 +863,7 @@ def _record_field(self, py_field: dataclasses.Field) -> RecordField:
default=default,
options=self.options,
)

return field_obj


Expand Down Expand Up @@ -901,6 +903,10 @@ def _record_field(self, name: str, py_field: pydantic.fields.FieldInfo) -> Recor
)
return field_obj

def make_default(self, py_default: pydantic.BaseModel) -> JSONObj:
"""Return an Avro schema compliant default value for a given Python value"""
return {key: _schema_obj(value.__class__).make_default(value) for key, value in py_default}

def _annotation(self, field_name: str) -> Type:
"""
Fetch the raw annotation for a given field name
Expand Down
9 changes: 3 additions & 6 deletions tests/test_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ def test_str_annotated():


def test_str_subclass():
class PyType(str):
...
class PyType(str): ...

expected = {
"type": "string",
Expand All @@ -55,8 +54,7 @@ class PyType(str):


def test_str_subclass_namespaced():
class PyType(str):
...
class PyType(str): ...

expected = {
"type": "string",
Expand All @@ -68,8 +66,7 @@ class PyType(str):
def test_str_subclass_other_classes():
import packaging.version

class PyType(packaging.version.Version, str):
...
class PyType(packaging.version.Version, str): ...

expected = {
"type": "string",
Expand Down
38 changes: 38 additions & 0 deletions tests/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,3 +523,41 @@ class PyType(Base):
],
}
assert_schema(PyType, expected)


def test_base_model_defaults():
class Default(pydantic.BaseModel):
field_a: str = "default_a"

class PyType(pydantic.BaseModel):
default: Default = pydantic.Field(..., default_factory=Default)

class Nested(pydantic.BaseModel):
py_type: PyType = pydantic.Field(..., default_factory=PyType)

expected = {
"fields": [
{
"default": {"default": {"field_a": "default_a"}},
"name": "py_type",
"type": {
"fields": [
{
"default": {"field_a": "default_a"},
"name": "default",
"type": {
"fields": [{"default": "default_a", "name": "field_a", "type": "string"}],
"name": "Default",
"type": "record",
},
}
],
"name": "PyType",
"type": "record",
},
}
],
"name": "Nested",
"type": "record",
}
assert_schema(Nested, expected)

0 comments on commit d7289db

Please sign in to comment.