fix(llmobs): safely handle non-json serializable arguments [backport #10694 to 2.12] #10726

Merged · 5 commits · Sep 23, 2024
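This backport replaces direct json.dumps(..., default=_unserializable_default_repr) calls with a single safe_json helper from ddtrace.llmobs._utils. The helper itself is not shown in this diff; a minimal sketch of what such a wrapper might look like, assuming it falls back to repr() for unserializable values and logs rather than raises on failure:

import json
import logging

log = logging.getLogger(__name__)


def _unserializable_default_repr(obj):
    # json.dumps calls this hook for any object it cannot encode natively.
    return repr(obj)


def safe_json(obj):
    # Hypothetical sketch: serialize without ever raising, so tagging a span
    # cannot crash the instrumented application. The empty-string fallback is
    # an assumption, not the library's confirmed behavior.
    try:
        return json.dumps(obj, default=_unserializable_default_repr)
    except Exception:
        log.warning("Failed to serialize object to JSON", exc_info=True)
        return ""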
10 changes: 4 additions & 6 deletions ddtrace/contrib/internal/anthropic/utils.py
@@ -1,10 +1,9 @@
-import json
from typing import Any
from typing import Optional

from ddtrace.internal.logger import get_logger
from ddtrace.llmobs._integrations.anthropic import _get_attr
-from ddtrace.llmobs._utils import _unserializable_default_repr
+from ddtrace.llmobs._utils import safe_json


log = get_logger(__name__)
@@ -39,7 +38,7 @@ def tag_tool_use_input_on_span(integration, span, chat_input, message_idx, block_idx):
    )
    span.set_tag_str(
        "anthropic.request.messages.%d.content.%d.tool_call.input" % (message_idx, block_idx),
-        integration.trunc(json.dumps(_get_attr(chat_input, "input", {}), default=_unserializable_default_repr)),
+        integration.trunc(safe_json(_get_attr(chat_input, "input", {}))),
    )


@@ -80,8 +79,7 @@ def tag_tool_use_output_on_span(integration, span, chat_completion, idx):
        span.set_tag_str("anthropic.response.completions.content.%d.tool_call.name" % idx, str(tool_name))
    if tool_inputs:
        span.set_tag_str(
-            "anthropic.response.completions.content.%d.tool_call.input" % idx,
-            integration.trunc(json.dumps(tool_inputs, default=_unserializable_default_repr)),
+            "anthropic.response.completions.content.%d.tool_call.input" % idx, integration.trunc(safe_json(tool_inputs))
        )


@@ -92,7 +90,7 @@ def tag_params_on_span(span, kwargs, integration):
            span.set_tag_str("anthropic.request.system", integration.trunc(str(v)))
        elif k not in ("messages", "model"):
            tagged_params[k] = v
-    span.set_tag_str("anthropic.request.parameters", json.dumps(tagged_params, default=_unserializable_default_repr))
+    span.set_tag_str("anthropic.request.parameters", safe_json(tagged_params))


def _extract_api_key(instance: Any) -> Optional[str]:
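For context on why the swap matters: json.dumps raises TypeError on objects it cannot encode, so before this change a tag write survived such input only if every call site remembered to pass the default= hook. A small illustration; ToolInput is a made-up stand-in, and the graceful fallback attributed to safe_json is the assumption from the sketch above:

import json


class ToolInput:
    """Stand-in for a non-JSON-serializable SDK object."""


try:
    json.dumps({"input": ToolInput()})
except TypeError as exc:
    print("json.dumps raised:", exc)

# With a repr() fallback (what safe_json is assumed to apply everywhere),
# serialization degrades to a readable string instead of raising:
print(json.dumps({"input": ToolInput()}, default=repr))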
17 changes: 11 additions & 6 deletions ddtrace/contrib/internal/langchain/patch.py
@@ -59,6 +59,7 @@
from ddtrace.internal.utils.formats import deep_getattr
from ddtrace.internal.utils.version import parse_version
from ddtrace.llmobs._integrations import LangChainIntegration
+from ddtrace.llmobs._utils import safe_json
from ddtrace.pin import Pin


@@ -466,9 +467,11 @@ def traced_chat_model_generate(langchain, pin, func, instance, args, kwargs):
                "messages": [
                    [
                        {
-                            "content": message.get("content", "")
-                            if isinstance(message, dict)
-                            else str(getattr(message, "content", "")),
+                            "content": (
+                                message.get("content", "")
+                                if isinstance(message, dict)
+                                else str(getattr(message, "content", ""))
+                            ),
                            "message_type": message.__class__.__name__,
                        }
                        for message in messages
@@ -596,9 +599,11 @@ async def traced_chat_model_agenerate(langchain, pin, func, instance, args, kwargs):
                "messages": [
                    [
                        {
-                            "content": message.get("content", "")
-                            if isinstance(message, dict)
-                            else str(getattr(message, "content", "")),
+                            "content": (
+                                message.get("content", "")
+                                if isinstance(message, dict)
+                                else str(getattr(message, "content", ""))
+                            ),
                            "message_type": message.__class__.__name__,
                        }
                        for message in messages
17 changes: 8 additions & 9 deletions ddtrace/llmobs/_integrations/anthropic.py
@@ -17,8 +17,8 @@
from ddtrace.llmobs._constants import OUTPUT_TOKENS_METRIC_KEY
from ddtrace.llmobs._constants import SPAN_KIND
from ddtrace.llmobs._constants import TOTAL_TOKENS_METRIC_KEY
-
-from .base import BaseLLMIntegration
+from ddtrace.llmobs._integrations.base import BaseLLMIntegration
+from ddtrace.llmobs._utils import safe_json


log = get_logger(__name__)
@@ -70,18 +70,17 @@ def llmobs_set_tags(

        span.set_tag_str(SPAN_KIND, "llm")
        span.set_tag_str(MODEL_NAME, span.get_tag("anthropic.request.model") or "")
-        span.set_tag_str(INPUT_MESSAGES, json.dumps(input_messages))
-        span.set_tag_str(METADATA, json.dumps(parameters))
+        span.set_tag_str(INPUT_MESSAGES, safe_json(input_messages))
+        span.set_tag_str(METADATA, safe_json(parameters))
        span.set_tag_str(MODEL_PROVIDER, "anthropic")
        if err or resp is None:
-            span.set_tag_str(OUTPUT_MESSAGES, json.dumps([{"content": ""}]))
+            span.set_tag_str(OUTPUT_MESSAGES, safe_json([{"content": ""}]))
        else:
            output_messages = self._extract_output_message(resp)
-            span.set_tag_str(OUTPUT_MESSAGES, json.dumps(output_messages))
-
+            span.set_tag_str(OUTPUT_MESSAGES, safe_json(output_messages))
        usage = self._get_llmobs_metrics_tags(span)
-        if usage != {}:
-            span.set_tag_str(METRICS, json.dumps(usage))
+        if usage:
+            span.set_tag_str(METRICS, safe_json(usage))

    def _extract_input_message(self, messages, system_prompt=None):
        """Extract input messages from the stored prompt.
13 changes: 7 additions & 6 deletions ddtrace/llmobs/_integrations/bedrock.py
@@ -1,4 +1,3 @@
-import json
from typing import Any
from typing import Dict
from typing import Optional

@@ -19,6 +18,7 @@
from ddtrace.llmobs._constants import TOTAL_TOKENS_METRIC_KEY
from ddtrace.llmobs._integrations import BaseLLMIntegration
from ddtrace.llmobs._utils import _get_llmobs_parent_id
+from ddtrace.llmobs._utils import safe_json


log = get_logger(__name__)
@@ -50,14 +50,15 @@ def llmobs_set_tags(
        span.set_tag_str(SPAN_KIND, "llm")
        span.set_tag_str(MODEL_NAME, span.get_tag("bedrock.request.model") or "")
        span.set_tag_str(MODEL_PROVIDER, span.get_tag("bedrock.request.model_provider") or "")
-        span.set_tag_str(INPUT_MESSAGES, json.dumps(input_messages))
-        span.set_tag_str(METADATA, json.dumps(parameters))
+        span.set_tag_str(INPUT_MESSAGES, safe_json(input_messages))
+        span.set_tag_str(METADATA, safe_json(parameters))
        if err or formatted_response is None:
-            span.set_tag_str(OUTPUT_MESSAGES, json.dumps([{"content": ""}]))
+            span.set_tag_str(OUTPUT_MESSAGES, safe_json([{"content": ""}]))
        else:
            output_messages = self._extract_output_message(formatted_response)
-            span.set_tag_str(OUTPUT_MESSAGES, json.dumps(output_messages))
-            span.set_tag_str(METRICS, json.dumps(self._llmobs_metrics(span, formatted_response)))
+            span.set_tag_str(OUTPUT_MESSAGES, safe_json(output_messages))
+            metrics = self._llmobs_metrics(span, formatted_response)
+            span.set_tag_str(METRICS, safe_json(metrics))

    @staticmethod
    def _llmobs_metrics(span: Span, formatted_response: Optional[Dict[str, Any]]) -> Dict[str, Any]:
157 changes: 157 additions & 0 deletions ddtrace/llmobs/_integrations/gemini.py
@@ -0,0 +1,157 @@
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional

from ddtrace import Span
from ddtrace.internal.utils import get_argument_value
from ddtrace.llmobs._constants import INPUT_MESSAGES
from ddtrace.llmobs._constants import INPUT_TOKENS_METRIC_KEY
from ddtrace.llmobs._constants import METADATA
from ddtrace.llmobs._constants import METRICS
from ddtrace.llmobs._constants import MODEL_NAME
from ddtrace.llmobs._constants import MODEL_PROVIDER
from ddtrace.llmobs._constants import OUTPUT_MESSAGES
from ddtrace.llmobs._constants import OUTPUT_TOKENS_METRIC_KEY
from ddtrace.llmobs._constants import SPAN_KIND
from ddtrace.llmobs._constants import TOTAL_TOKENS_METRIC_KEY
from ddtrace.llmobs._integrations.base import BaseLLMIntegration
from ddtrace.llmobs._utils import _get_attr
from ddtrace.llmobs._utils import safe_json


class GeminiIntegration(BaseLLMIntegration):
_integration_name = "gemini"

def _set_base_span_tags(
self, span: Span, provider: Optional[str] = None, model: Optional[str] = None, **kwargs: Dict[str, Any]
) -> None:
if provider is not None:
span.set_tag_str("google_generativeai.request.provider", str(provider))
if model is not None:
span.set_tag_str("google_generativeai.request.model", str(model))

def llmobs_set_tags(
self, span: Span, args: List[Any], kwargs: Dict[str, Any], instance: Any, generations: Any = None
) -> None:
if not self.llmobs_enabled:
return

span.set_tag_str(SPAN_KIND, "llm")
span.set_tag_str(MODEL_NAME, span.get_tag("google_generativeai.request.model") or "")
span.set_tag_str(MODEL_PROVIDER, span.get_tag("google_generativeai.request.provider") or "")

metadata = self._llmobs_set_metadata(kwargs, instance)
span.set_tag_str(METADATA, safe_json(metadata))

system_instruction = _get_attr(instance, "_system_instruction", None)
input_contents = get_argument_value(args, kwargs, 0, "contents")
input_messages = self._extract_input_message(input_contents, system_instruction)
span.set_tag_str(INPUT_MESSAGES, safe_json(input_messages))

if span.error or generations is None:
span.set_tag_str(OUTPUT_MESSAGES, safe_json([{"content": ""}]))
else:
output_messages = self._extract_output_message(generations)
span.set_tag_str(OUTPUT_MESSAGES, safe_json(output_messages))

usage = self._get_llmobs_metrics_tags(span)
if usage:
span.set_tag_str(METRICS, safe_json(usage))

@staticmethod
def _llmobs_set_metadata(kwargs, instance):
metadata = {}
model_config = instance._generation_config or {}
request_config = kwargs.get("generation_config", {})
parameters = ("temperature", "max_output_tokens", "candidate_count", "top_p", "top_k")
for param in parameters:
model_config_value = _get_attr(model_config, param, None)
request_config_value = _get_attr(request_config, param, None)
if model_config_value or request_config_value:
metadata[param] = request_config_value or model_config_value
return metadata

@staticmethod
def _extract_message_from_part(part, role):
text = _get_attr(part, "text", "")
function_call = _get_attr(part, "function_call", None)
function_response = _get_attr(part, "function_response", None)
message = {"content": text}
if role:
message["role"] = role
if function_call:
function_call_dict = function_call
if not isinstance(function_call, dict):
function_call_dict = type(function_call).to_dict(function_call)
message["tool_calls"] = [
{"name": function_call_dict.get("name", ""), "arguments": function_call_dict.get("args", {})}
]
if function_response:
function_response_dict = function_response
if not isinstance(function_response, dict):
function_response_dict = type(function_response).to_dict(function_response)
message["content"] = "[tool result: {}]".format(function_response_dict.get("response", ""))
return message

def _extract_input_message(self, contents, system_instruction=None):
messages = []
if system_instruction:
for part in system_instruction.parts:
messages.append({"content": part.text or "", "role": "system"})
if isinstance(contents, str):
messages.append({"content": contents})
return messages
if isinstance(contents, dict):
message = {"content": contents.get("text", "")}
if contents.get("role", None):
message["role"] = contents["role"]
messages.append(message)
return messages
if not isinstance(contents, list):
messages.append({"content": "[Non-text content object: {}]".format(repr(contents))})
return messages
for content in contents:
if isinstance(content, str):
messages.append({"content": content})
continue
role = _get_attr(content, "role", None)
parts = _get_attr(content, "parts", [])
if not parts or not isinstance(parts, Iterable):
message = {"content": "[Non-text content object: {}]".format(repr(content))}
if role:
message["role"] = role
messages.append(message)
continue
for part in parts:
message = self._extract_message_from_part(part, role)
messages.append(message)
return messages

def _extract_output_message(self, generations):
output_messages = []
generations_dict = generations.to_dict()
for candidate in generations_dict.get("candidates", []):
content = candidate.get("content", {})
role = content.get("role", "model")
parts = content.get("parts", [])
for part in parts:
message = self._extract_message_from_part(part, role)
output_messages.append(message)
return output_messages

@staticmethod
def _get_llmobs_metrics_tags(span):
usage = {}
input_tokens = span.get_metric("google_generativeai.response.usage.prompt_tokens")
output_tokens = span.get_metric("google_generativeai.response.usage.completion_tokens")
total_tokens = span.get_metric("google_generativeai.response.usage.total_tokens")

if input_tokens is not None:
usage[INPUT_TOKENS_METRIC_KEY] = input_tokens
if output_tokens is not None:
usage[OUTPUT_TOKENS_METRIC_KEY] = output_tokens
if total_tokens is not None:
usage[TOTAL_TOKENS_METRIC_KEY] = total_tokens
return usage
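The new Gemini integration normalizes several input shapes (plain strings, single dicts, and lists of content objects with role/parts) into a flat message list. A simplified, runnable re-statement of the dispatch in _extract_input_message, using plain dicts in place of real google.generativeai content objects; the function name is illustrative, not part of the library:

from typing import Any, Dict, List


def extract_input_messages_sketch(contents: Any) -> List[Dict[str, Any]]:
    # Mirrors the shape-dispatch in GeminiIntegration._extract_input_message,
    # minus system-instruction handling and protobuf part conversion.
    if isinstance(contents, str):
        return [{"content": contents}]
    if isinstance(contents, dict):
        message = {"content": contents.get("text", "")}
        if contents.get("role"):
            message["role"] = contents["role"]
        return [message]
    if not isinstance(contents, list):
        return [{"content": "[Non-text content object: {}]".format(repr(contents))}]
    messages = []
    for content in contents:
        if isinstance(content, str):
            messages.append({"content": content})
            continue
        role = content.get("role")
        for part in content.get("parts", []):
            message = {"content": part.get("text", "")}
            if role:
                message["role"] = role
            messages.append(message)
    return messages


print(extract_input_messages_sketch("hello"))
print(extract_input_messages_sketch({"text": "hi", "role": "user"}))
print(extract_input_messages_sketch([{"role": "user", "parts": [{"text": "hey"}]}]))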