diff --git a/src/py_avro_schema/__init__.py b/src/py_avro_schema/__init__.py index 994aedf..94c943b 100644 --- a/src/py_avro_schema/__init__.py +++ b/src/py_avro_schema/__init__.py @@ -23,7 +23,7 @@ """ import importlib.metadata -from typing import Optional, Type +from typing import Any, Callable, Optional, Type, Union import memoization import orjson @@ -50,6 +50,7 @@ def generate( *, namespace: Optional[str] = None, options: Option = Option(0), + orjson_default: Union[Callable[[Any], Any], None] = None, ) -> bytes: """ Return an Avro schema as a JSON-formatted bytestring for a given Python class or instance @@ -60,11 +61,12 @@ def generate( :param namespace: The Avro namespace to add to schemas. :param options: Schema generation options as defined by :class:`Option` enum values. Specify multiple values like this: ``Option.INT_32 | Option.FLOAT_32``. + :param orjson_default: A function to serialize custom types with orjson. """ schema_dict = schema(py_type, namespace=namespace, options=options) json_options = 0 for opt in JSON_OPTIONS: if opt in options: json_options |= opt.value - schema_json = orjson.dumps(schema_dict, option=json_options) + schema_json = orjson.dumps(schema_dict, option=json_options, default=orjson_default) return schema_json diff --git a/tests/test_avro_schema.py b/tests/test_avro_schema.py index 0824ff2..d516452 100644 --- a/tests/test_avro_schema.py +++ b/tests/test_avro_schema.py @@ -10,9 +10,11 @@ # specific language governing permissions and limitations under the License. import dataclasses +from typing import Any import avro.schema import orjson +import pydantic import py_avro_schema as pas @@ -43,3 +45,43 @@ class PyType: json_data = pas.generate(PyType) assert json_data == orjson.dumps(expected) assert avro.schema.parse(json_data) + + +def test_pydantic_field_default(): + class Default(pydantic.BaseModel): + """My Default""" + + foo: str = "foo" + + class PyType(pydantic.BaseModel): + """My PyType""" + + default: Default = pydantic.Field(default_factory=Default) + + def pydantic_serializer(value: Any) -> dict: + if isinstance(value, pydantic.BaseModel): + return value.model_dump(mode="json") + raise TypeError + + expected = { + "type": "record", + "name": "PyType", + "fields": [ + { + "name": "default", + "type": { + "type": "record", + "name": "Default", + "fields": [{"name": "foo", "type": "string", "default": "foo"}], + "namespace": "test_avro_schema", + "doc": "My Default", + }, + "default": {"foo": "foo"}, + } + ], + "namespace": "test_avro_schema", + "doc": "My PyType", + } + json_data = pas.generate(PyType, orjson_default=pydantic_serializer) + assert json_data == orjson.dumps(expected) + assert avro.schema.parse(json_data)