From 6267bb6b52eafd1651ad0936464eabd03667900d Mon Sep 17 00:00:00 2001 From: Daniel Gellert Date: Mon, 18 Mar 2024 17:54:57 +0100 Subject: [PATCH] feat: make_default of SequenceSchema can cope with default items (#65) --- src/py_avro_schema/_schemas.py | 6 +++++- tests/test_dataclass.py | 19 ++++++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/src/py_avro_schema/_schemas.py b/src/py_avro_schema/_schemas.py index d263390..202f4d8 100644 --- a/src/py_avro_schema/_schemas.py +++ b/src/py_avro_schema/_schemas.py @@ -567,6 +567,10 @@ def data(self, names: NamesType) -> JSONObj: "items": self.items_schema.data(names=names), } + def make_default(self, py_default: Any) -> Any: + origin = get_origin(self.py_type) or self.py_type + return origin(self.items_schema.make_default(item) for item in py_default) + class DictSchema(Schema): """An Avro map schema for a given Python mapping""" @@ -905,7 +909,7 @@ def _record_field(self, name: str, py_field: pydantic.fields.FieldInfo) -> Recor def make_default(self, py_default: pydantic.BaseModel) -> JSONObj: """Return an Avro schema compliant default value for a given Python value""" - return {key: _schema_obj(py_default.__annotations__[key]).make_default(value) for key, value in py_default} + return {key: _schema_obj(self._annotation(key)).make_default(value) for key, value in py_default} def _annotation(self, field_name: str) -> Type: """ diff --git a/tests/test_dataclass.py b/tests/test_dataclass.py index 0ee9cd5..22ebf9e 100644 --- a/tests/test_dataclass.py +++ b/tests/test_dataclass.py @@ -14,7 +14,7 @@ import decimal import enum import re -from typing import Annotated, Dict, List, Optional +from typing import Annotated, Dict, List, Optional, Tuple import pytest import typeguard @@ -823,3 +823,20 @@ class PyType: ], } assert_schema(PyType, expected, do_doc=True) + + +def test_sequence_schema_defaults_with_items(): + @dataclasses.dataclass + class PyType: + field_a: List[str] = dataclasses.field(default_factory=lambda: ["foo", "bar"]) + field_b: Tuple[str, str] = dataclasses.field(default_factory=lambda: ("foo", "bar")) + + expected = { + "fields": [ + {"default": ["foo", "bar"], "name": "field_a", "type": {"items": "string", "type": "array"}}, + {"default": ("foo", "bar"), "name": "field_b", "type": {"items": "string", "type": "array"}}, + ], + "name": "PyType", + "type": "record", + } + assert_schema(PyType, expected)