From a3b16f3a45630655e9a384a8c7070ec98ee0ec55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Fri, 20 Sep 2024 17:42:34 +0200 Subject: [PATCH] fix(batch-exports): Account for all broken things in JSON --- posthog/temporal/batch_exports/utils.py | 21 ++++++++++----- .../batch_exports/test_batch_export_utils.py | 27 ++++++++++++++++++- 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/posthog/temporal/batch_exports/utils.py b/posthog/temporal/batch_exports/utils.py index 5c2c581a90f7a..c54e983795838 100644 --- a/posthog/temporal/batch_exports/utils.py +++ b/posthog/temporal/batch_exports/utils.py @@ -2,6 +2,7 @@ import collections.abc import contextlib import functools +import re import typing import uuid @@ -124,27 +125,33 @@ class JsonScalar(pa.ExtensionScalar): def as_py(self) -> dict | None: """Try to convert value to Python representation. - We attempt to decode the value returned by `as_py` as JSON 3 times: - 1. As returned by `as_py`, without changes. - 2. By escaping and replacing any encoding errors. - 3. By treating the value as a string and surrouding it with quotes. + We attempt to decode the value returned by `as_py` as JSON. However, to do so safely we must + ensure the value is a valid UTF-8 string. We try to take care of the following scenarios: + * Unescaped whitespace characters (e.g. \n). + * Unescaped surrogates without a pair. + * Escaped surrogates without a pair. + * Not a JSON document (we assume it's a string and quote it). If all else fails, we will log the offending value and re-raise the decoding error. """ if self.value: value = self.value.as_py() + if re.search("([\t\n\r\f\v])", value): + value = value.encode("unicode-escape").decode("utf-8") + if not value: return None - json_bytes = value.encode("utf-8") + json_bytes = value.encode("utf-8", "replace") try: return orjson.loads(json_bytes) except orjson.JSONDecodeError: pass - json_bytes = value.encode("unicode-escape").decode("utf-8", "replace").encode("unicode-escape") + json_bytes = json_bytes.decode("raw-unicode-escape").encode("utf-8", "replace") + try: return orjson.loads(json_bytes) except orjson.JSONDecodeError: @@ -154,7 +161,7 @@ def as_py(self) -> dict | None: # Handles non-valid JSON strings like `'"$set": "Something"'` by quoting them. value = f'"{value}"' - json_bytes = value.encode("unicode-escape").decode("utf-8", "replace").encode("unicode-escape") + json_bytes = value.encode("utf-8", "replace").decode("raw-unicode-escape").encode("utf-8", "replace") try: return orjson.loads(json_bytes) except orjson.JSONDecodeError: diff --git a/posthog/temporal/tests/batch_exports/test_batch_export_utils.py b/posthog/temporal/tests/batch_exports/test_batch_export_utils.py index 968eab5c0d723..13201fc75301a 100644 --- a/posthog/temporal/tests/batch_exports/test_batch_export_utils.py +++ b/posthog/temporal/tests/batch_exports/test_batch_export_utils.py @@ -1,11 +1,16 @@ import asyncio import datetime as dt +import pyarrow as pa import pytest import pytest_asyncio from posthog.batch_exports.models import BatchExportRun -from posthog.temporal.batch_exports.utils import make_retryable_with_exponential_backoff, set_status_to_running_task +from posthog.temporal.batch_exports.utils import ( + JsonType, + make_retryable_with_exponential_backoff, + set_status_to_running_task, +) from posthog.temporal.common.logger import bind_temporal_worker_logger from posthog.temporal.tests.utils.models import ( acreate_batch_export, @@ -162,3 +167,23 @@ async def raise_value_error(): await make_retryable_with_exponential_backoff(raise_value_error, retryable_exceptions=(TypeError,))() assert counter == 1 + + +@pytest.mark.parametrize( + "input,expected", + [ + ([b'{"asdf": "\udee5\ud83e\udee5\\ud83e"}'], [{"asdf": "????"}]), + ([b'{"asdf": "\\"Hello\\" \\udfa2"}'], [{"asdf": '"Hello" ?'}]), + ([b'{"asdf": "\n"}'], [{"asdf": "\n"}]), + ( + [b'{"finally": "a", "normal": "json", "thing": 1, "bool": false}'], + [{"finally": "a", "normal": "json", "thing": 1, "bool": False}], + ), + ], +) +def test_json_type_as_py(input, expected): + array = pa.array(input) + casted_array = array.cast(JsonType()) + result = casted_array.to_pylist() + + assert result == expected