Skip to content

Commit

Permalink
Python: Fix schema handling. Fix function result return for type list. (
Browse files Browse the repository at this point in the history
#6370)

### Motivation and Context

Building the tools JSON payload from the kernel parameter metadata
did not properly include schema objects of type `array`.

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

### Description

Correctly include the object type `array` so that the tool call doesn't
return a bad request. Add unit tests.
- Closes #6367 
- Closes #6360
- Fixes the FunctionResult return for a type string -- if the
FunctionResult is of type KernelContent then return the first element of
the list, otherwise return the complete list.
- Fix the kernel function from method to include the proper type_object
for the return parameter so that the schema can be created properly.
- Add retry logic for a sometimes flaky function calling stepwise
planner integration test.
- Add a check during function calling that makes sure the model is
returning the proper number of arguments based on how many function
arguments are required.

<!-- Describe your changes, the overall approach, and the underlying design.
These notes will help others understand how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [X] The code builds clean without any errors or warnings
- [X] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [X] All unit tests pass, and I have added new tests where possible
- [X] I didn't break anyone 😄
  • Loading branch information
moonbox3 committed May 23, 2024
1 parent e98cd18 commit 0c95173
Show file tree
Hide file tree
Showing 24 changed files with 1,090 additions and 677 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,22 @@ async def _process_function_call(
chat_history.add_message(message=frc.to_chat_message_content())
return

num_required_func_params = len([param for param in function_to_call.parameters if param.is_required])
if len(parsed_args) < num_required_func_params:
msg = (
f"There are `{num_required_func_params}` tool call arguments required and "
f"only `{len(parsed_args)}` received. The required arguments are: "
f"{[param.name for param in function_to_call.parameters if param.is_required]}. "
"Please provide the required arguments and try again."
)
logger.exception(msg)
frc = FunctionResultContent.from_function_call_content_and_result(
function_call_content=function_call,
result=msg,
)
chat_history.add_message(message=frc.to_chat_message_content())
return

_rebuild_auto_function_invocation_context()
invocation_context = AutoFunctionInvocationContext(
function=function_to_call,
Expand Down
29 changes: 20 additions & 9 deletions python/semantic_kernel/connectors/ai/open_ai/services/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,30 @@ def kernel_function_metadata_to_openai_tool_format(metadata: KernelFunctionMetad

def parse_schema(schema_data):
    """Recursively parse the schema data to include nested properties.

    Args:
        schema_data: A JSON-schema-like dict, or None when no schema is known.

    Returns:
        A dict with at least ``type`` and ``description``. Objects also carry
        their recursively parsed ``properties``, arrays carry ``items``, and an
        ``enum`` entry is copied through when the input declares one.
    """
    if schema_data is None:
        return {"type": "string", "description": ""}

    # Fall back to "string" when the type is absent so the payload never
    # carries a null type.
    schema_type = schema_data.get("type") or "string"
    schema_description = schema_data.get("description", "")

    if schema_type == "object":
        properties = {key: parse_schema(value) for key, value in schema_data.get("properties", {}).items()}
        return {
            "type": "object",
            "properties": properties,
            "description": schema_description,
        }

    if schema_type == "array":
        # Arrays without an explicit item schema are treated as arrays of strings.
        items = schema_data.get("items", {"type": "string"})
        return {"type": "array", "description": schema_description, "items": items}

    schema_dict = {"type": schema_type, "description": schema_description}
    if "enum" in schema_data:
        schema_dict["enum"] = schema_data["enum"]

    return schema_dict

return {
"type": "function",
"function": {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
# Copyright (c) Microsoft. All rights reserved.

import re
from typing import Any
from urllib.parse import urlencode, urljoin, urlparse, urlunparse

from semantic_kernel.connectors.openapi_plugin.models.rest_api_operation_expected_response import (
RestApiOperationExpectedResponse,
)
from semantic_kernel.connectors.openapi_plugin.models.rest_api_operation_parameter import RestApiOperationParameter
from semantic_kernel.connectors.openapi_plugin.models.rest_api_operation_parameter_location import (
RestApiOperationParameterLocation,
)
from semantic_kernel.connectors.openapi_plugin.models.rest_api_operation_parameter_style import (
RestApiOperationParameterStyle,
)
from semantic_kernel.connectors.openapi_plugin.models.rest_api_operation_payload import RestApiOperationPayload
from semantic_kernel.connectors.openapi_plugin.models.rest_api_operation_payload_property import (
RestApiOperationPayloadProperty,
)
from semantic_kernel.exceptions.function_exceptions import FunctionExecutionException
from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata
from semantic_kernel.utils.experimental_decorator import experimental_class


@experimental_class
class RestApiOperation:
    """Describe a single REST API operation and build the parts of its request.

    The instance stores the operation id, HTTP method, server URL, path
    template, parameters, request body and expected responses, and offers
    helpers that turn a dict of arguments into headers, a path, a query string
    and a full operation URL.
    """

    MEDIA_TYPE_TEXT_PLAIN = "text/plain"
    PAYLOAD_ARGUMENT_NAME = "payload"
    CONTENT_TYPE_ARGUMENT_NAME = "content-type"
    # Anything outside [0-9A-Za-z_] is replaced when deriving alternative parameter names.
    INVALID_SYMBOLS_REGEX = re.compile(r"[^0-9A-Za-z_]+")

    # Response keys searched, in this order, when picking the default response.
    _preferred_responses: list[str] = [
        "200",
        "201",
        "202",
        "203",
        "204",
        "205",
        "206",
        "207",
        "208",
        "226",
        "2XX",
        "default",
    ]

    def __init__(
        self,
        id: str,
        method: str,
        server_url: str,
        path: str,
        summary: str | None = None,
        description: str | None = None,
        params: list["RestApiOperationParameter"] | None = None,
        request_body: "RestApiOperationPayload | None" = None,
        responses: dict[str, "RestApiOperationExpectedResponse"] | None = None,
    ):
        """Store the operation metadata; the HTTP method is normalised to upper case.

        ``server_url`` may be a plain string or an already parsed URL object
        exposing ``geturl()``; both forms are accepted by ``get_server_url``.
        """
        self.id = id
        self.method = method.upper()
        self.server_url = server_url
        self.path = path
        self.summary = summary
        self.description = description
        self.parameters = params
        self.request_body = request_body
        self.responses = responses

    @staticmethod
    def _url_to_string(url) -> str:
        """Return ``url`` unchanged when it is a string, otherwise its ``geturl()`` value."""
        return url if isinstance(url, str) else url.geturl()

    def _parameters_at(self, location: RestApiOperationParameterLocation) -> list["RestApiOperationParameter"]:
        """Return the parameters declared at ``location``; empty when the operation has none."""
        return [p for p in (self.parameters or []) if p.location == location]

    def url_join(self, base_url: str, path: str):
        """Join a base URL and a path, correcting for any missing slashes."""
        parsed_base = urlparse(base_url)
        base_path = parsed_base.path if parsed_base.path.endswith("/") else parsed_base.path + "/"
        full_path = urljoin(base_path, path.lstrip("/"))
        return urlunparse(parsed_base._replace(path=full_path))

    def build_headers(self, arguments: dict[str, Any]) -> dict[str, str]:
        """Build the request headers from ``arguments``.

        Optional header parameters without an argument are skipped.

        Raises:
            FunctionExecutionException: A required header parameter has no argument.
        """
        headers: dict[str, str] = {}
        for parameter in self._parameters_at(RestApiOperationParameterLocation.HEADER):
            argument = arguments.get(parameter.name)
            if argument is None:
                if parameter.is_required:
                    raise FunctionExecutionException(
                        f"No argument is provided for the `{parameter.name}` "
                        f"required parameter of the operation - `{self.id}`."
                    )
                continue
            headers[parameter.name] = str(argument)
        return headers

    def build_operation_url(self, arguments, server_url_override=None, api_host_url=None):
        """Return the operation URL: the resolved server URL joined with the filled-in path."""
        server_url = self.get_server_url(server_url_override, api_host_url)
        path = self.build_path(self.path, arguments)
        return urljoin(server_url.geturl(), path.lstrip("/"))

    def get_server_url(self, server_url_override=None, api_host_url=None):
        """Resolve the server URL and return it parsed, with a trailing slash.

        Precedence: a non-empty ``server_url_override``, then the operation's
        own ``server_url``, then ``api_host_url``.

        Raises:
            FunctionExecutionException: None of the three sources provides a URL.
        """
        if server_url_override is not None and self._url_to_string(server_url_override):
            server_url_string = self._url_to_string(server_url_override)
        elif self.server_url:
            server_url_string = self._url_to_string(self.server_url)
        elif api_host_url:
            server_url_string = self._url_to_string(api_host_url)
        else:
            raise FunctionExecutionException(f"No server URL is available for the operation - `{self.id}`.")

        # make sure the base URL ends with a trailing slash
        if not server_url_string.endswith("/"):
            server_url_string += "/"

        return urlparse(server_url_string)

    def build_path(self, path_template: str, arguments: dict[str, Any]) -> str:
        """Substitute ``{name}`` placeholders in ``path_template`` with path arguments.

        Placeholders of optional parameters without an argument are left as they are.

        Raises:
            FunctionExecutionException: A required path parameter has no argument.
        """
        for parameter in self._parameters_at(RestApiOperationParameterLocation.PATH):
            argument = arguments.get(parameter.name)
            if argument is None:
                if parameter.is_required:
                    raise FunctionExecutionException(
                        f"No argument is provided for the `{parameter.name}` "
                        f"required parameter of the operation - `{self.id}`."
                    )
                continue
            path_template = path_template.replace(f"{{{parameter.name}}}", str(argument))
        return path_template

    def build_query_string(self, arguments: dict[str, Any]) -> str:
        """Return the URL-encoded query string built from the query parameters.

        Raises:
            FunctionExecutionException: A required query parameter has no argument.
        """
        segments = []
        for parameter in self._parameters_at(RestApiOperationParameterLocation.QUERY):
            argument = arguments.get(parameter.name)
            if argument is None:
                if parameter.is_required:
                    raise FunctionExecutionException(
                        f"No argument or value is provided for the `{parameter.name}` "
                        f"required parameter of the operation - `{self.id}`."
                    )
                continue
            segments.append((parameter.name, argument))
        return urlencode(segments)

    def replace_invalid_symbols(self, parameter_name):
        """Replace every run of characters outside ``[0-9A-Za-z_]`` with a single underscore."""
        return RestApiOperation.INVALID_SYMBOLS_REGEX.sub("_", parameter_name)

    def get_parameters(
        self,
        operation: "RestApiOperation",
        add_payload_params_from_metadata: bool = True,
        enable_payload_spacing: bool = False,
    ) -> list["RestApiOperationParameter"]:
        """Return the operation's parameters, plus payload parameters when it has a request body.

        Each returned parameter gets its ``alternative_name`` set to its name
        with invalid symbols replaced; the parameter objects are updated in place.
        """
        params = list(operation.parameters or [])
        if operation.request_body is not None:
            params.extend(
                self.get_payload_parameters(
                    operation=operation,
                    use_parameters_from_metadata=add_payload_params_from_metadata,
                    enable_namespacing=enable_payload_spacing,
                )
            )

        for parameter in params:
            parameter.alternative_name = self.replace_invalid_symbols(parameter.name)

        return params

    def create_payload_artificial_parameter(self, operation: "RestApiOperation") -> "RestApiOperationParameter":
        """Create the required ``payload`` body parameter that stands for the whole request body.

        Its type is ``string`` for a ``text/plain`` body and ``object`` otherwise.
        """
        request_body = operation.request_body
        is_plain_text = bool(request_body) and request_body.media_type == RestApiOperation.MEDIA_TYPE_TEXT_PLAIN
        return RestApiOperationParameter(
            name=self.PAYLOAD_ARGUMENT_NAME,
            type="string" if is_plain_text else "object",
            is_required=True,
            location=RestApiOperationParameterLocation.BODY,
            style=RestApiOperationParameterStyle.SIMPLE,
            description=request_body.description if request_body else "REST API request body.",
            schema=request_body.schema if request_body else None,
        )

    def create_content_type_artificial_parameter(self) -> "RestApiOperationParameter":
        """Create the optional ``content-type`` body parameter."""
        return RestApiOperationParameter(
            name=self.CONTENT_TYPE_ARGUMENT_NAME,
            type="string",
            is_required=False,
            location=RestApiOperationParameterLocation.BODY,
            style=RestApiOperationParameterStyle.SIMPLE,
            description="Content type of REST API request body.",
        )

    def _get_property_name(
        self, property: RestApiOperationPayloadProperty, root_property_name: str | None, enable_namespacing: bool
    ):
        """Return the property name, prefixed with ``root_property_name`` when namespacing applies."""
        if enable_namespacing and root_property_name:
            return f"{root_property_name}.{property.name}"
        return property.name

    def _get_parameters_from_payload_metadata(
        self,
        properties: list["RestApiOperationPayloadProperty"],
        enable_namespacing: bool = False,
        root_property_name: str | None = None,
    ) -> list["RestApiOperationParameter"]:
        """Flatten payload properties into body parameters.

        A property without nested properties becomes one parameter. Nested
        properties are visited recursively, with the current parameter name as
        the namespace root.
        """
        parameters: list[RestApiOperationParameter] = []
        # A leaf may expose no nested properties at all; treat that as empty.
        for prop in properties or []:
            parameter_name = self._get_property_name(prop, root_property_name, enable_namespacing)
            if not prop.properties:
                parameters.append(
                    RestApiOperationParameter(
                        name=parameter_name,
                        type=prop.type,
                        is_required=prop.is_required,
                        location=RestApiOperationParameterLocation.BODY,
                        style=RestApiOperationParameterStyle.SIMPLE,
                        description=prop.description,
                        schema=prop.schema,
                    )
                )
            parameters.extend(
                self._get_parameters_from_payload_metadata(prop.properties, enable_namespacing, parameter_name)
            )
        return parameters

    def get_payload_parameters(
        self, operation: "RestApiOperation", use_parameters_from_metadata: bool, enable_namespacing: bool
    ) -> list["RestApiOperationParameter"]:
        """Return the parameters that represent the operation's request body.

        With ``use_parameters_from_metadata`` the parameters are derived from
        the request body properties; a ``text/plain`` body yields only the
        artificial ``payload`` parameter. Otherwise the artificial ``payload``
        and ``content-type`` parameters are returned.

        Raises:
            FunctionExecutionException: Metadata is requested but the operation has no request body.
        """
        if use_parameters_from_metadata:
            if operation.request_body is None:
                raise FunctionExecutionException(
                    f"Payload parameters cannot be retrieved from the `{operation.id}` "
                    f"operation payload metadata because it is missing."
                )
            if operation.request_body.media_type == RestApiOperation.MEDIA_TYPE_TEXT_PLAIN:
                return [self.create_payload_artificial_parameter(operation)]

            return self._get_parameters_from_payload_metadata(operation.request_body.properties, enable_namespacing)

        return [
            self.create_payload_artificial_parameter(operation),
            self.create_content_type_artificial_parameter(),
        ]

    def get_default_response(
        self, responses: dict[str, RestApiOperationExpectedResponse], preferred_responses: list[str]
    ) -> RestApiOperationExpectedResponse | None:
        """Return the first response whose key appears in ``preferred_responses``, else None."""
        if not responses:
            return None
        for code in preferred_responses:
            if code in responses:
                return responses[code]
        # If no appropriate response is found, return None
        return None

    def get_default_return_parameter(
        self, preferred_responses: list[str] | None = None
    ) -> KernelParameterMetadata | None:
        """Describe the default response as a ``return`` parameter.

        Uses the class-level preferred response keys when ``preferred_responses``
        is None. Returns None when the operation has no matching response.
        """
        if preferred_responses is None:
            preferred_responses = self._preferred_responses

        rest_operation_response = self.get_default_response(self.responses, preferred_responses)

        if rest_operation_response:
            return KernelParameterMetadata(
                name="return",
                description=rest_operation_response.description,
                type_=rest_operation_response.schema.get("type") if rest_operation_response.schema else None,
                schema_data=rest_operation_response.schema,
            )

        return None
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) Microsoft. All rights reserved.


from semantic_kernel.utils.experimental_decorator import experimental_class


@experimental_class
class RestApiOperationExpectedResponse:
    """Hold the description, media type and optional schema of an expected response."""

    def __init__(self, description: str, media_type: str, schema: dict | str | None = None):
        """Store the response metadata.

        Args:
            description: Description of the response.
            media_type: Media type of the response body.
            schema: Schema of the response body, if any. Consumers that call
                ``.get("type")`` on it need a mapping, so a dict is accepted
                alongside the previously annotated string form.
        """
        self.description = description
        self.media_type = media_type
        self.schema = schema
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) Microsoft. All rights reserved.

from typing import Any

from semantic_kernel.connectors.openapi_plugin.models.rest_api_operation_expected_response import (
RestApiOperationExpectedResponse,
)
from semantic_kernel.connectors.openapi_plugin.models.rest_api_operation_parameter_location import (
RestApiOperationParameterLocation,
)
from semantic_kernel.connectors.openapi_plugin.models.rest_api_operation_parameter_style import (
RestApiOperationParameterStyle,
)
from semantic_kernel.utils.experimental_decorator import experimental_class


@experimental_class
class RestApiOperationParameter:
    """Plain container for the attributes of one REST API operation parameter."""

    def __init__(
        self,
        name: str,
        type: str,
        location: RestApiOperationParameterLocation,
        style: RestApiOperationParameterStyle | None = None,
        alternative_name: str | None = None,
        description: str | None = None,
        is_required: bool = False,
        default_value: Any | None = None,
        schema: str | None = None,
        response: RestApiOperationExpectedResponse | None = None,
    ):
        """Store every constructor argument on the attribute of the same name."""
        # Name and declared type of the parameter.
        self.name, self.type = name, type
        # Location and style values supplied by the caller.
        self.location, self.style = location, style
        # Optional alternative name and description.
        self.alternative_name, self.description = alternative_name, description
        # Required flag and default value.
        self.is_required, self.default_value = is_required, default_value
        # Optional schema and expected response.
        self.schema, self.response = schema, response
Loading

0 comments on commit 0c95173

Please sign in to comment.