Skip to content

Commit

Permalink
fix: use pydantic.BaseModel.__annotations__ when getting defaults (#68)…
Browse files Browse the repository at this point in the history
… (#69)
  • Loading branch information
faph authored Mar 19, 2024
2 parents d7289db + f457690 commit a92d04f
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 2 deletions.
9 changes: 8 additions & 1 deletion src/py_avro_schema/_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,13 @@ def data(self, names: NamesType) -> JSONObj:
"items": self.items_schema.data(names=names),
}

def make_default(self, py_default: collections.abc.Sequence) -> JSONArray:
"""Return an Avro schema compliant default value for a given Python Sequence
:param py_default: The Python sequence to generate a default value for.
"""
return [self.items_schema.make_default(item) for item in py_default]


class DictSchema(Schema):
"""An Avro map schema for a given Python mapping"""
Expand Down Expand Up @@ -905,7 +912,7 @@ def _record_field(self, name: str, py_field: pydantic.fields.FieldInfo) -> Recor

def make_default(self, py_default: pydantic.BaseModel) -> JSONObj:
"""Return an Avro schema compliant default value for a given Python value"""
return {key: _schema_obj(value.__class__).make_default(value) for key, value in py_default}
return {key: _schema_obj(self._annotation(key)).make_default(value) for key, value in py_default}

def _annotation(self, field_name: str) -> Type:
"""
Expand Down
19 changes: 18 additions & 1 deletion tests/test_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import decimal
import enum
import re
from typing import Annotated, Dict, List, Optional
from typing import Annotated, Dict, List, Optional, Tuple

import pytest
import typeguard
Expand Down Expand Up @@ -823,3 +823,20 @@ class PyType:
],
}
assert_schema(PyType, expected, do_doc=True)


def test_sequence_schema_defaults_with_items():
@dataclasses.dataclass
class PyType:
field_a: List[str] = dataclasses.field(default_factory=lambda: ["foo", "bar"])
field_b: Tuple[str, str] = dataclasses.field(default_factory=lambda: ("foo", "bar"))

expected = {
"fields": [
{"default": ["foo", "bar"], "name": "field_a", "type": {"items": "string", "type": "array"}},
{"default": ["foo", "bar"], "name": "field_b", "type": {"items": "string", "type": "array"}},
],
"name": "PyType",
"type": "record",
}
assert_schema(PyType, expected)
34 changes: 34 additions & 0 deletions tests/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,3 +561,37 @@ class Nested(pydantic.BaseModel):
"type": "record",
}
assert_schema(Nested, expected)


def test_nested_base_model_list_default():
class Default(pydantic.BaseModel):
field_a: List[str] = pydantic.Field(..., default_factory=list)

class PyType(pydantic.BaseModel):
default: Default = pydantic.Field(..., default_factory=Default)

expected = {
"fields": [
{
"default": {"field_a": []},
"name": "default",
"type": {
"fields": [
{
"default": [],
"name": "field_a",
"type": {
"type": "array",
"items": "string",
},
}
],
"name": "Default",
"type": "record",
},
}
],
"name": "PyType",
"type": "record",
}
assert_schema(PyType, expected)

0 comments on commit a92d04f

Please sign in to comment.