diff --git a/pyproject.toml b/pyproject.toml index 2998f20..3003ac1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ requires-python = ">=3.9" dependencies = [ "avro~=1.10", "memoization~=0.4", + "more-itertools~=10.0", "orjson~=3.5", "typeguard~=4.0", ] diff --git a/src/py_avro_schema/_schemas.py b/src/py_avro_schema/_schemas.py index 6701e87..b85e4c5 100644 --- a/src/py_avro_schema/_schemas.py +++ b/src/py_avro_schema/_schemas.py @@ -44,6 +44,7 @@ get_type_hints, ) +import more_itertools import orjson import typeguard @@ -696,9 +697,17 @@ def __init__(self, py_type: Type[Union[Any]], namespace: Optional[str] = None, o args = get_args(py_type) self.item_schemas = [_schema_obj(arg, namespace=namespace, options=options) for arg in args] - def data(self, names: NamesType) -> JSONArray: + def data(self, names: NamesType) -> JSONType: """Return the schema data""" - return [schema.data(names=names) for schema in self.item_schemas] + # Render the item schemas + schemas = (item_schema.data(names=names) for item_schema in self.item_schemas) + # We need to deduplicate the schemas **after** rendering. This is because **different** Python types might + # result in the **same** Avro schema. Preserving order as order may be significant in an Avro schema. + unique_schemas = list(more_itertools.unique_everseen(schemas)) + if len(unique_schemas) > 1: + return unique_schemas + else: + return unique_schemas[0] def sort_item_schemas(self, default_value: Any) -> None: """Re-order the union's schemas such that the first item corresponds with a record field's default value""" diff --git a/tests/test_primitives.py b/tests/test_primitives.py index 7653414..4969b72 100644 --- a/tests/test_primitives.py +++ b/tests/test_primitives.py @@ -352,6 +352,18 @@ def test_union_of_union_string_int(): assert_schema(py_type, expected) +def test_union_str_str(): + py_type = Union[str, str] + expected = "string" + assert_schema(py_type, expected) + + +def test_union_str_annotated_str(): + py_type = Union[str, Annotated[str, ...]] + expected = "string" + assert_schema(py_type, expected) + + def test_literal_different_types(): py_type = Literal["", 42] with pytest.raises(