fix(llmobs): safely handle non-json serializable arguments [backport #10694 to 2.12] (#10726)

Backports #10694 to 2.12.

Safely handle non-JSON serializable tag arguments in LLMObs.annotate()
and the OpenAI/LangChain/Bedrock/Anthropic integrations.
- LLMObs.annotate(): previously, we discarded the entire argument to
annotate() if it was not JSON serializable. Now, we safely convert
non-JSON serializable fields/objects to a default placeholder text,
meaning users can still send *some* data even if part of it is invalid.
- The same applies to each integration: we safely handle non-JSON
serializable args and fall back to placeholder text where necessary.
- We've moved all `json.dumps()` calls into a private helper `safe_json()`,
which does the above for us (see the sketch below).
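
For reference, here is a minimal sketch of what `safe_json()` does. This is
an illustration, not the verbatim helper from this commit; the exact
`json.dumps()` options and placeholder wording are assumptions:

```python
import json
import logging

log = logging.getLogger(__name__)


def _unserializable_default_repr(obj):
    # json.dumps() calls this for any object it cannot serialize natively;
    # we substitute a placeholder string instead of failing the whole dump.
    log.warning("I/O object is not JSON serializable. Defaulting to placeholder value instead.")
    return "[Unserializable object: {}]".format(repr(obj))


def safe_json(obj):
    # Strings are returned as-is to avoid double-encoding pre-serialized values.
    if isinstance(obj, str):
        return obj
    try:
        return json.dumps(obj, ensure_ascii=False, default=_unserializable_default_repr)
    except Exception:
        log.error("Failed to serialize object to JSON.", exc_info=True)
```

With a helper like this, a payload containing, say, a client object or an
open file handle still serializes to valid JSON; only the offending field
is replaced by a placeholder built from its `repr()`.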

Note: this PR removes some tests in `test_llmobs_service.py` that covered
truly unserializable objects, since that case is highly unlikely: someone
would have to go out of their way to create a truly unserializable object
(i.e. override `__repr__()` to return a non-JSON serializable value). We
still catch these cases, so any resulting crashes should not come from
our code.
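
To illustrate, here is a hypothetical example of a "truly unserializable"
object, building on the `safe_json()` sketch above: `repr()` itself fails,
so even the placeholder fallback raises, and the helper can only catch the
error and log it:

```python
class TrulyUnserializable:
    # __repr__ must return a str; returning bytes makes repr(obj) raise
    # TypeError, so even the placeholder fallback cannot produce a value.
    def __repr__(self):
        return b"not-a-string"


# safe_json(TrulyUnserializable()) -> json.dumps() invokes the default=
# fallback, repr() raises TypeError, and safe_json() logs the error and
# returns None instead of crashing the instrumented application.
```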

## Checklist
- [x] PR author has checked that all the criteria below are met
- The PR description includes an overview of the change
- The PR description articulates the motivation for the change
- The change includes tests OR the PR description describes a testing
strategy
- The PR description notes risks associated with the change, if any
- Newly-added code is easy to change
- The change follows the [library release note
guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html)
- The change includes or references documentation updates if necessary
- Backport labels are set (if
[applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting))

## Reviewer Checklist
- [x] Reviewer has checked that all the criteria below are met 
- Title is accurate
- All changes are related to the pull request's stated goal
- Avoids breaking
[API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces)
changes
- Testing strategy adequately addresses listed risks
- Newly-added code is easy to change
- Release note makes sense to a user of the library
- If necessary, author has acknowledged and discussed the performance
implications of this PR as reported in the benchmarks PR comment
- Backport labels are set in a manner that is consistent with the
[release branch maintenance
policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)
Yun-Kim committed Sep 23, 2024
1 parent 845719b commit bf33c3c
Showing 16 changed files with 258 additions and 230 deletions.
10 changes: 4 additions & 6 deletions ddtrace/contrib/internal/anthropic/utils.py
@@ -1,10 +1,9 @@
import json
from typing import Any
from typing import Optional

from ddtrace.internal.logger import get_logger
from ddtrace.llmobs._integrations.anthropic import _get_attr
from ddtrace.llmobs._utils import _unserializable_default_repr
from ddtrace.llmobs._utils import safe_json


log = get_logger(__name__)
@@ -39,7 +38,7 @@ def tag_tool_use_input_on_span(integration, span, chat_input, message_idx, block_idx):
)
span.set_tag_str(
"anthropic.request.messages.%d.content.%d.tool_call.input" % (message_idx, block_idx),
integration.trunc(json.dumps(_get_attr(chat_input, "input", {}), default=_unserializable_default_repr)),
integration.trunc(safe_json(_get_attr(chat_input, "input", {}))),
)


@@ -80,8 +79,7 @@ def tag_tool_use_output_on_span(integration, span, chat_completion, idx):
span.set_tag_str("anthropic.response.completions.content.%d.tool_call.name" % idx, str(tool_name))
if tool_inputs:
span.set_tag_str(
"anthropic.response.completions.content.%d.tool_call.input" % idx,
integration.trunc(json.dumps(tool_inputs, default=_unserializable_default_repr)),
"anthropic.response.completions.content.%d.tool_call.input" % idx, integration.trunc(safe_json(tool_inputs))
)


@@ -92,7 +90,7 @@ def tag_params_on_span(span, kwargs, integration):
span.set_tag_str("anthropic.request.system", integration.trunc(str(v)))
elif k not in ("messages", "model"):
tagged_params[k] = v
span.set_tag_str("anthropic.request.parameters", json.dumps(tagged_params, default=_unserializable_default_repr))
span.set_tag_str("anthropic.request.parameters", safe_json(tagged_params))


def _extract_api_key(instance: Any) -> Optional[str]:
16 changes: 10 additions & 6 deletions ddtrace/contrib/internal/langchain/patch.py
@@ -466,9 +466,11 @@ def traced_chat_model_generate(langchain, pin, func, instance, args, kwargs):
"messages": [
[
{
"content": message.get("content", "")
if isinstance(message, dict)
else str(getattr(message, "content", "")),
"content": (
message.get("content", "")
if isinstance(message, dict)
else str(getattr(message, "content", ""))
),
"message_type": message.__class__.__name__,
}
for message in messages
@@ -596,9 +598,11 @@ async def traced_chat_model_agenerate(langchain, pin, func, instance, args, kwargs):
"messages": [
[
{
"content": message.get("content", "")
if isinstance(message, dict)
else str(getattr(message, "content", "")),
"content": (
message.get("content", "")
if isinstance(message, dict)
else str(getattr(message, "content", ""))
),
"message_type": message.__class__.__name__,
}
for message in messages
17 changes: 8 additions & 9 deletions ddtrace/llmobs/_integrations/anthropic.py
@@ -17,8 +17,8 @@
from ddtrace.llmobs._constants import OUTPUT_TOKENS_METRIC_KEY
from ddtrace.llmobs._constants import SPAN_KIND
from ddtrace.llmobs._constants import TOTAL_TOKENS_METRIC_KEY

from .base import BaseLLMIntegration
from ddtrace.llmobs._integrations.base import BaseLLMIntegration
from ddtrace.llmobs._utils import safe_json


log = get_logger(__name__)
@@ -70,18 +70,17 @@ def llmobs_set_tags(

span.set_tag_str(SPAN_KIND, "llm")
span.set_tag_str(MODEL_NAME, span.get_tag("anthropic.request.model") or "")
span.set_tag_str(INPUT_MESSAGES, json.dumps(input_messages))
span.set_tag_str(METADATA, json.dumps(parameters))
span.set_tag_str(INPUT_MESSAGES, safe_json(input_messages))
span.set_tag_str(METADATA, safe_json(parameters))
span.set_tag_str(MODEL_PROVIDER, "anthropic")
if err or resp is None:
span.set_tag_str(OUTPUT_MESSAGES, json.dumps([{"content": ""}]))
span.set_tag_str(OUTPUT_MESSAGES, safe_json([{"content": ""}]))
else:
output_messages = self._extract_output_message(resp)
span.set_tag_str(OUTPUT_MESSAGES, json.dumps(output_messages))

span.set_tag_str(OUTPUT_MESSAGES, safe_json(output_messages))
usage = self._get_llmobs_metrics_tags(span)
if usage != {}:
span.set_tag_str(METRICS, json.dumps(usage))
if usage:
span.set_tag_str(METRICS, safe_json(usage))

def _extract_input_message(self, messages, system_prompt=None):
"""Extract input messages from the stored prompt.
13 changes: 7 additions & 6 deletions ddtrace/llmobs/_integrations/bedrock.py
@@ -1,4 +1,3 @@
import json
from typing import Any
from typing import Dict
from typing import Optional
@@ -19,6 +18,7 @@
from ddtrace.llmobs._constants import TOTAL_TOKENS_METRIC_KEY
from ddtrace.llmobs._integrations import BaseLLMIntegration
from ddtrace.llmobs._utils import _get_llmobs_parent_id
from ddtrace.llmobs._utils import safe_json


log = get_logger(__name__)
@@ -50,14 +50,15 @@ def llmobs_set_tags(
span.set_tag_str(SPAN_KIND, "llm")
span.set_tag_str(MODEL_NAME, span.get_tag("bedrock.request.model") or "")
span.set_tag_str(MODEL_PROVIDER, span.get_tag("bedrock.request.model_provider") or "")
span.set_tag_str(INPUT_MESSAGES, json.dumps(input_messages))
span.set_tag_str(METADATA, json.dumps(parameters))
span.set_tag_str(INPUT_MESSAGES, safe_json(input_messages))
span.set_tag_str(METADATA, safe_json(parameters))
if err or formatted_response is None:
span.set_tag_str(OUTPUT_MESSAGES, json.dumps([{"content": ""}]))
span.set_tag_str(OUTPUT_MESSAGES, safe_json([{"content": ""}]))
else:
output_messages = self._extract_output_message(formatted_response)
span.set_tag_str(OUTPUT_MESSAGES, json.dumps(output_messages))
span.set_tag_str(METRICS, json.dumps(self._llmobs_metrics(span, formatted_response)))
span.set_tag_str(OUTPUT_MESSAGES, safe_json(output_messages))
metrics = self._llmobs_metrics(span, formatted_response)
span.set_tag_str(METRICS, safe_json(metrics))

@staticmethod
def _llmobs_metrics(span: Span, formatted_response: Optional[Dict[str, Any]]) -> Dict[str, Any]:
48 changes: 15 additions & 33 deletions ddtrace/llmobs/_integrations/langchain.py
@@ -1,4 +1,3 @@
import json
from typing import Any
from typing import Dict
from typing import List
@@ -20,9 +19,9 @@
from ddtrace.llmobs._constants import OUTPUT_MESSAGES
from ddtrace.llmobs._constants import OUTPUT_VALUE
from ddtrace.llmobs._constants import SPAN_KIND

from ..utils import Document
from .base import BaseLLMIntegration
from ddtrace.llmobs._integrations.base import BaseLLMIntegration
from ddtrace.llmobs._utils import safe_json
from ddtrace.llmobs.utils import Document


log = get_logger(__name__)
@@ -89,7 +88,7 @@ def llmobs_set_tags(
self._llmobs_set_meta_tags_from_chain(span, inputs, response, error)
elif operation == "embedding":
self._llmobs_set_meta_tags_from_embedding(span, inputs, response, error, is_workflow=is_workflow)
span.set_tag_str(METRICS, json.dumps({}))
span.set_tag_str(METRICS, safe_json({}))

def _llmobs_set_metadata(self, span: Span, model_provider: Optional[str] = None) -> None:
if not model_provider:
@@ -110,7 +109,7 @@ def _llmobs_set_metadata(self, span: Span, model_provider: Optional[str] = None) -> None:
if max_tokens is not None and max_tokens != "None":
metadata["max_tokens"] = int(max_tokens)
if metadata:
span.set_tag_str(METADATA, json.dumps(metadata))
span.set_tag_str(METADATA, safe_json(metadata))

def _llmobs_set_meta_tags_from_llm(
self, span: Span, prompts: List[Any], completions: Any, err: bool = False, is_workflow: bool = False
@@ -125,12 +124,12 @@ def _llmobs_set_meta_tags_from_llm(
if isinstance(prompts, str):
prompts = [prompts]

span.set_tag_str(input_tag_key, json.dumps([{"content": str(prompt)} for prompt in prompts]))
span.set_tag_str(input_tag_key, safe_json([{"content": str(prompt)} for prompt in prompts]))

message_content = [{"content": ""}]
if not err:
message_content = [{"content": completion[0].text} for completion in completions.generations]
span.set_tag_str(output_tag_key, json.dumps(message_content))
span.set_tag_str(output_tag_key, safe_json(message_content))

def _llmobs_set_meta_tags_from_chat_model(
self,
@@ -157,7 +156,7 @@ def _llmobs_set_meta_tags_from_chat_model(
"role": getattr(message, "role", ROLE_MAPPING.get(message.type, "")),
}
)
span.set_tag_str(input_tag_key, json.dumps(input_messages))
span.set_tag_str(input_tag_key, safe_json(input_messages))

output_messages = [{"content": ""}]
if not err:
Expand All @@ -172,7 +171,7 @@ def _llmobs_set_meta_tags_from_chat_model(
"role": role,
}
)
span.set_tag_str(output_tag_key, json.dumps(output_messages))
span.set_tag_str(output_tag_key, safe_json(output_messages))

def _llmobs_set_meta_tags_from_chain(
self,
@@ -184,25 +183,13 @@ def _llmobs_set_meta_tags_from_chain(
span.set_tag_str(SPAN_KIND, "workflow")

if inputs is not None:
try:
formatted_inputs = self.format_io(inputs)
if isinstance(formatted_inputs, str):
span.set_tag_str(INPUT_VALUE, formatted_inputs)
else:
span.set_tag_str(INPUT_VALUE, json.dumps(self.format_io(inputs)))
except TypeError:
log.warning("Failed to serialize chain input data to JSON")
formatted_inputs = self.format_io(inputs)
span.set_tag_str(INPUT_VALUE, safe_json(formatted_inputs))
if error:
span.set_tag_str(OUTPUT_VALUE, "")
elif outputs is not None:
try:
formatted_outputs = self.format_io(outputs)
if isinstance(formatted_outputs, str):
span.set_tag_str(OUTPUT_VALUE, formatted_outputs)
else:
span.set_tag_str(OUTPUT_VALUE, json.dumps(self.format_io(outputs)))
except TypeError:
log.warning("Failed to serialize chain output data to JSON")
formatted_outputs = self.format_io(outputs)
span.set_tag_str(OUTPUT_VALUE, safe_json(formatted_outputs))

def _llmobs_set_meta_tags_from_embedding(
self,
@@ -227,17 +214,12 @@
):
if is_workflow:
formatted_inputs = self.format_io(input_texts)
formatted_str = (
formatted_inputs
if isinstance(formatted_inputs, str)
else json.dumps(self.format_io(input_texts))
)
span.set_tag_str(input_tag_key, formatted_str)
span.set_tag_str(input_tag_key, safe_json(formatted_inputs))
else:
if isinstance(input_texts, str):
input_texts = [input_texts]
input_documents = [Document(text=str(doc)) for doc in input_texts]
span.set_tag_str(input_tag_key, json.dumps(input_documents))
span.set_tag_str(input_tag_key, safe_json(input_documents))
except TypeError:
log.warning("Failed to serialize embedding input data to JSON")
if error:
42 changes: 19 additions & 23 deletions ddtrace/llmobs/_integrations/openai.py
@@ -21,8 +21,9 @@
from ddtrace.llmobs._constants import OUTPUT_VALUE
from ddtrace.llmobs._constants import SPAN_KIND
from ddtrace.llmobs._constants import TOTAL_TOKENS_METRIC_KEY
from ddtrace.llmobs._integrations.anthropic import _get_attr
from ddtrace.llmobs._integrations.base import BaseLLMIntegration
from ddtrace.llmobs._utils import _unserializable_default_repr
from ddtrace.llmobs._utils import safe_json
from ddtrace.llmobs.utils import Document
from ddtrace.pin import Pin

@@ -151,9 +152,8 @@ def llmobs_set_tags(
self._llmobs_set_meta_tags_from_chat(resp, err, kwargs, streamed_completions, span)
elif operation == "embedding":
self._llmobs_set_meta_tags_from_embedding(resp, err, kwargs, span)
span.set_tag_str(
METRICS, json.dumps(self._set_llmobs_metrics_tags(span, resp, streamed_completions is not None))
)
metrics = self._set_llmobs_metrics_tags(span, resp, streamed_completions is not None)
span.set_tag_str(METRICS, safe_json(metrics))

@staticmethod
def _llmobs_set_meta_tags_from_completion(
@@ -163,20 +163,19 @@ def _llmobs_set_meta_tags_from_completion(
prompt = kwargs.get("prompt", "")
if isinstance(prompt, str):
prompt = [prompt]
span.set_tag_str(INPUT_MESSAGES, json.dumps([{"content": str(p)} for p in prompt]))
span.set_tag_str(INPUT_MESSAGES, safe_json([{"content": str(p)} for p in prompt]))

parameters = {k: v for k, v in kwargs.items() if k not in ("model", "prompt")}
span.set_tag_str(METADATA, json.dumps(parameters, default=_unserializable_default_repr))
span.set_tag_str(METADATA, safe_json(parameters))

if err is not None:
span.set_tag_str(OUTPUT_MESSAGES, json.dumps([{"content": ""}]))
span.set_tag_str(OUTPUT_MESSAGES, safe_json([{"content": ""}]))
return
if streamed_completions:
span.set_tag_str(
OUTPUT_MESSAGES, json.dumps([{"content": choice["text"]} for choice in streamed_completions])
)
return
span.set_tag_str(OUTPUT_MESSAGES, json.dumps([{"content": choice.text} for choice in resp.choices]))
messages = [{"content": _get_attr(choice, "text", "")} for choice in streamed_completions]
else:
messages = [{"content": _get_attr(choice, "text", "")} for choice in resp.choices]
span.set_tag_str(OUTPUT_MESSAGES, safe_json(messages))

@staticmethod
def _llmobs_set_meta_tags_from_chat(
@@ -185,17 +184,14 @@
"""Extract prompt/response tags from a chat completion and set them as temporary "_ml_obs.meta.*" tags."""
input_messages = []
for m in kwargs.get("messages", []):
if isinstance(m, dict):
input_messages.append({"content": str(m.get("content", "")), "role": str(m.get("role", ""))})
continue
input_messages.append({"content": str(getattr(m, "content", "")), "role": str(getattr(m, "role", ""))})
span.set_tag_str(INPUT_MESSAGES, json.dumps(input_messages))
input_messages.append({"content": str(_get_attr(m, "content", "")), "role": str(_get_attr(m, "role", ""))})
span.set_tag_str(INPUT_MESSAGES, safe_json(input_messages))

parameters = {k: v for k, v in kwargs.items() if k not in ("model", "messages", "tools", "functions")}
span.set_tag_str(METADATA, json.dumps(parameters, default=_unserializable_default_repr))
span.set_tag_str(METADATA, safe_json(parameters))

if err is not None:
span.set_tag_str(OUTPUT_MESSAGES, json.dumps([{"content": ""}]))
span.set_tag_str(OUTPUT_MESSAGES, safe_json([{"content": ""}]))
return
if streamed_messages:
messages = []
@@ -204,7 +200,7 @@
messages.append({"content": message["formatted_content"], "role": message["role"]})
continue
messages.append({"content": message["content"], "role": message["role"]})
span.set_tag_str(OUTPUT_MESSAGES, json.dumps(messages))
span.set_tag_str(OUTPUT_MESSAGES, safe_json(messages))
return
output_messages = []
for idx, choice in enumerate(resp.choices):
@@ -234,7 +230,7 @@
output_messages.append({"content": content, "role": choice.message.role, "tool_calls": tool_calls_info})
else:
output_messages.append({"content": content, "role": choice.message.role})
span.set_tag_str(OUTPUT_MESSAGES, json.dumps(output_messages))
span.set_tag_str(OUTPUT_MESSAGES, safe_json(output_messages))

@staticmethod
def _llmobs_set_meta_tags_from_embedding(resp: Any, err: Any, kwargs: Dict[str, Any], span: Span) -> None:
@@ -243,15 +239,15 @@ def _llmobs_set_meta_tags_from_embedding(resp: Any, err: Any, kwargs: Dict[str, Any], span: Span) -> None:
metadata = {"encoding_format": encoding_format}
if kwargs.get("dimensions"):
metadata["dimensions"] = kwargs.get("dimensions")
span.set_tag_str(METADATA, json.dumps(metadata))
span.set_tag_str(METADATA, safe_json(metadata))

embedding_inputs = kwargs.get("input", "")
if isinstance(embedding_inputs, str) or isinstance(embedding_inputs[0], int):
embedding_inputs = [embedding_inputs]
input_documents = []
for doc in embedding_inputs:
input_documents.append(Document(text=str(doc)))
span.set_tag_str(INPUT_DOCUMENTS, json.dumps(input_documents))
span.set_tag_str(INPUT_DOCUMENTS, safe_json(input_documents))

if err is not None:
return