Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MINOR] Support Pydantic model title #73

Merged
merged 4 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions src/py_avro_schema/_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import decimal
import enum
import inspect
import re
import sys
import types
import uuid
Expand Down Expand Up @@ -108,8 +109,14 @@ class Option(enum.Flag):
#: Do not populate ``doc`` schema attributes based on Python docstrings
NO_DOC = enum.auto()

#: Use the alias specified in a classes ``Field`` instead of the field's name.
#: This currently only affects Pydantic Models
#: Use an alias specified as part of a class instead of the class name itself.
#: This currently affects Pydantic models only.
#: See https://docs.pydantic.dev/dev/api/config/#pydantic.config.ConfigDict.title
USE_CLASS_ALIAS = enum.auto()

#: Use the alias specified in a class field instead of the field/attribute name itself.
#: This currently affects Pydantic models only.
#: See https://docs.pydantic.dev/dev/api/fields/#pydantic.fields.Field
USE_FIELD_ALIAS = enum.auto()


Expand Down Expand Up @@ -162,6 +169,17 @@ def _schema_obj(py_type: Type, namespace: Optional[str] = None, options: Option
raise TypeNotSupportedError(f"Cannot generate Avro schema for Python type {py_type}")


# See https://avro.apache.org/docs/1.11.1/specification/#names
_AVRO_NAME_PATTERN = re.compile(r"^[A-Za-z]([A-Za-z0-9_])*$")


def validate_name(value: str) -> str:
"""Validate (and return) whether a given string is a valid Avro name"""
if not re.match(_AVRO_NAME_PATTERN, value):
raise ValueError(f"'{value}' is not a valid Avro name")
return value


class Schema(abc.ABC):
"""Schema base"""

Expand Down Expand Up @@ -690,6 +708,16 @@ def __str__(self):
"""Human rendering of the schema"""
return self.fullname

@property
def name(self):
"""Return the schema name"""
return self._name

@name.setter
def name(self, value: str):
"""Validate and set the schema name"""
self._name = validate_name(value)

@property
def fullname(self):
"""The schema's full name including the namespace if set"""
Expand Down Expand Up @@ -897,6 +925,8 @@ def __init__(self, py_type: Type[pydantic.BaseModel], namespace: Optional[str] =
:param options: Schema generation options.
"""
super().__init__(py_type, namespace=namespace, options=options)
if Option.USE_CLASS_ALIAS in self.options:
self.name = py_type.model_config.get("title") or self.name
self.py_fields = py_type.model_fields
self.record_fields = [self._record_field(name, field) for name, field in self.py_fields.items()]

Expand Down
34 changes: 33 additions & 1 deletion tests/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def test_field_alias_generator():
class PyType(pydantic.BaseModel):
field_a: str

model_config = {"alias_generator": lambda x: x.upper()}
model_config = pydantic.ConfigDict(alias_generator=lambda x: x.upper())

expected = {
"type": "record",
Expand All @@ -441,6 +441,38 @@ class PyType(pydantic.BaseModel):
assert_schema(PyType, expected, options=pas.Option.USE_FIELD_ALIAS)


def test_class_title():
class PyType(pydantic.BaseModel):
model_config = pydantic.ConfigDict(title="PyTitle")

expected = {
"type": "record",
"name": "PyTitle",
"fields": [],
}
assert_schema(PyType, expected, options=pas.Option.USE_CLASS_ALIAS)


def test_class_title_not_set():
class PyType(pydantic.BaseModel):
model_config = pydantic.ConfigDict()

expected = {
"type": "record",
"name": "PyType",
"fields": [],
}
assert_schema(PyType, expected, options=pas.Option.USE_CLASS_ALIAS)


def test_class_title_with_space():
class PyType(pydantic.BaseModel):
model_config = pydantic.ConfigDict(title="Py Title")

with pytest.raises(ValueError, match="'Py Title' is not a valid Avro name"):
assert_schema(PyType, {}, options=pas.Option.USE_CLASS_ALIAS)


def test_annotated_decimal():
class PyType(pydantic.BaseModel):
field_a: Annotated[
Expand Down
Loading