Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net: Added IChatHistoryReducer for ChatCompletion with AzureOpenAI and OpenAI #8894

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ private async Task InvokeAgentAsync(ChatCompletionAgent agent, int messageCount)
Console.WriteLine($"# {AuthorRole.User}: '{index}'");

// Reduce prior to invoking the agent
bool isReduced = await agent.ReduceAsync(chat);
bool isReduced = await agent.TryReduceAsync(chat);

// Invoke and display assistant response
await foreach (ChatMessageContent message in agent.InvokeAsync(chat))
Expand Down
4 changes: 2 additions & 2 deletions dotnet/src/Agents/Core/ChatHistoryChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public sealed class ChatHistoryChannel : AgentChannel
}

// Pre-process history reduction.
await historyAgent.ReduceAsync(this._history, cancellationToken).ConfigureAwait(false);
await historyAgent.TryReduceAsync(this._history, cancellationToken).ConfigureAwait(false);

// Capture the current message count to evaluate history mutation.
int messageCount = this._history.Count;
Expand Down Expand Up @@ -85,7 +85,7 @@ protected override async IAsyncEnumerable<StreamingChatMessageContent> InvokeStr
}

// Pre-process history reduction.
await historyAgent.ReduceAsync(this._history, cancellationToken).ConfigureAwait(false);
await historyAgent.TryReduceAsync(this._history, cancellationToken).ConfigureAwait(false);

int messageCount = this._history.Count;

Expand Down
10 changes: 5 additions & 5 deletions dotnet/src/Agents/Core/ChatHistoryKernelAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public abstract class ChatHistoryKernelAgent : KernelAgent
public KernelArguments? Arguments { get; init; }

/// <inheritdoc/>
public IChatHistoryReducer? HistoryReducer { get; init; }
public IAgentChatHistoryReducer? HistoryReducer { get; init; }

/// <inheritdoc/>
public abstract IAsyncEnumerable<ChatMessageContent> InvokeAsync(
Expand All @@ -45,9 +45,9 @@ public abstract IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingAsy
/// </summary>
/// <param name="history">The source history</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns></returns>
public Task<bool> ReduceAsync(ChatHistory history, CancellationToken cancellationToken = default) =>
history.ReduceAsync(this.HistoryReducer, cancellationToken);
/// <returns>A boolean indicating if the operation was successful or not.</returns>
public Task<bool> TryReduceAsync(ChatHistory history, CancellationToken cancellationToken = default) =>
history.TryReduceAsync(this.HistoryReducer, cancellationToken);

/// <inheritdoc/>
protected sealed override IEnumerable<string> GetChannelKeys()
Expand All @@ -59,7 +59,7 @@ protected sealed override IEnumerable<string> GetChannelKeys()
if (this.HistoryReducer != null)
{
// Explicitly include the reducer type to eliminate the possibility of hash collisions
// with custom implementations of IChatHistoryReducer.
// with custom implementations of IAgentChatHistoryReducer.
yield return this.HistoryReducer.GetType().FullName!;

yield return this.HistoryReducer.GetHashCode().ToString(CultureInfo.InvariantCulture);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ public static int LocateSafeReductionIndex(this IReadOnlyList<ChatMessageContent
/// Using the existing <see cref="ChatHistory"/> for a reduction in collection size eliminates the need
/// for re-allocation (of memory).
/// </remarks>
public static async Task<bool> ReduceAsync(this ChatHistory history, IChatHistoryReducer? reducer, CancellationToken cancellationToken)
public static async Task<bool> TryReduceAsync(this ChatHistory history, IAgentChatHistoryReducer? reducer, CancellationToken cancellationToken)
{
if (reducer == null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace Microsoft.SemanticKernel.Agents.History;
/// is provided (recommended), reduction will scan within the threshold window in an attempt to
/// avoid orphaning a user message from an assistant response.
/// </remarks>
public class ChatHistorySummarizationReducer : IChatHistoryReducer
public class ChatHistorySummarizationReducer : IAgentChatHistoryReducer
{
/// <summary>
/// Metadata key to indicate a summary message.
Expand Down Expand Up @@ -64,21 +64,21 @@ Provide a concise and complete summarization of the entire dialog that does not
public bool UseSingleSummary { get; init; } = true;

/// <inheritdoc/>
public async Task<IEnumerable<ChatMessageContent>?> ReduceAsync(IReadOnlyList<ChatMessageContent> history, CancellationToken cancellationToken = default)
public async Task<IEnumerable<ChatMessageContent>?> ReduceAsync(IReadOnlyList<ChatMessageContent> chatHistory, CancellationToken cancellationToken = default)
{
// Identify where summary messages end and regular history begins
int insertionPoint = history.LocateSummarizationBoundary(SummaryMetadataKey);
int insertionPoint = chatHistory.LocateSummarizationBoundary(SummaryMetadataKey);

// First pass to determine the truncation index
int truncationIndex = history.LocateSafeReductionIndex(this._targetCount, this._thresholdCount, insertionPoint);
int truncationIndex = chatHistory.LocateSafeReductionIndex(this._targetCount, this._thresholdCount, insertionPoint);

IEnumerable<ChatMessageContent>? truncatedHistory = null;

if (truncationIndex > 0)
{
// Second pass to extract history for summarization
IEnumerable<ChatMessageContent> summarizedHistory =
history.Extract(
chatHistory.Extract(
this.UseSingleSummary ? 0 : insertionPoint,
truncationIndex - 1,
(m) => m.Items.Any(i => i is FunctionCallContent || i is FunctionResultContent));
Expand Down Expand Up @@ -111,7 +111,7 @@ IEnumerable<ChatMessageContent> AssemblySummarizedHistory(ChatMessageContent? su
{
for (int index = 0; index <= insertionPoint - 1; ++index)
{
yield return history[index];
yield return chatHistory[index];
}
}

Expand All @@ -120,9 +120,9 @@ IEnumerable<ChatMessageContent> AssemblySummarizedHistory(ChatMessageContent? su
yield return summary;
}

for (int index = truncationIndex; index < history.Count; ++index)
for (int index = truncationIndex; index < chatHistory.Count; ++index)
{
yield return history[index];
yield return chatHistory[index];
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,20 @@ namespace Microsoft.SemanticKernel.Agents.History;
/// is provided (recommended), reduction will scan within the threshold window in an attempt to
/// avoid orphaning a user message from an assistant response.
/// </remarks>
public class ChatHistoryTruncationReducer : IChatHistoryReducer
public class ChatHistoryTruncationReducer : IAgentChatHistoryReducer
{
/// <inheritdoc/>
public Task<IEnumerable<ChatMessageContent>?> ReduceAsync(IReadOnlyList<ChatMessageContent> history, CancellationToken cancellationToken = default)
public Task<IEnumerable<ChatMessageContent>?> ReduceAsync(IReadOnlyList<ChatMessageContent> chatHistory, CancellationToken cancellationToken = default)
{
// First pass to determine the truncation index
int truncationIndex = history.LocateSafeReductionIndex(this._targetCount, this._thresholdCount);
int truncationIndex = chatHistory.LocateSafeReductionIndex(this._targetCount, this._thresholdCount);

IEnumerable<ChatMessageContent>? truncatedHistory = null;

if (truncationIndex > 0)
{
// Second pass to truncate the history
truncatedHistory = history.Extract(truncationIndex);
truncatedHistory = chatHistory.Extract(truncationIndex);
}

return Task.FromResult(truncatedHistory);
Expand Down
25 changes: 25 additions & 0 deletions dotnet/src/Agents/Core/History/IAgentChatHistoryReducer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (c) Microsoft. All rights reserved.
using Microsoft.SemanticKernel.ChatCompletion;

namespace Microsoft.SemanticKernel.Agents.History;

/// <summary>
/// Defines a contract for a reducing chat history to be used by agents.
/// </summary>
/// <remarks>
/// The additional interface methods are used to evaluate the equality of different reducers, which is necessary for the agent channel key.
/// </remarks>
public interface IAgentChatHistoryReducer : IChatHistoryReducer
{
/// <summary>
/// Each reducer shall override equality evaluation so that different reducers
/// of the same configuration can be evaluated for equivalency.
/// </summary>
bool Equals(object? obj);

/// <summary>
/// Each reducer shall implement custom hash-code generation so that different reducers
/// of the same configuration can be evaluated for equivalency.
/// </summary>
int GetHashCode();
}
32 changes: 0 additions & 32 deletions dotnet/src/Agents/Core/History/IChatHistoryReducer.cs

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,18 @@ public async Task VerifyChatHistoryNotReducedAsync()
{
// Arrange
ChatHistory history = [];
Mock<IChatHistoryReducer> mockReducer = new();
Mock<IAgentChatHistoryReducer> mockReducer = new();
mockReducer.Setup(r => r.ReduceAsync(It.IsAny<IReadOnlyList<ChatMessageContent>>(), default)).ReturnsAsync((IEnumerable<ChatMessageContent>?)null);

// Act
bool isReduced = await history.ReduceAsync(null, default);
bool isReduced = await history.TryReduceAsync(null, default);

// Assert
Assert.False(isReduced);
Assert.Empty(history);

// Act
isReduced = await history.ReduceAsync(mockReducer.Object, default);
isReduced = await history.TryReduceAsync(mockReducer.Object, default);

// Assert
Assert.False(isReduced);
Expand All @@ -110,13 +110,13 @@ public async Task VerifyChatHistoryNotReducedAsync()
public async Task VerifyChatHistoryReducedAsync()
{
// Arrange
Mock<IChatHistoryReducer> mockReducer = new();
Mock<IAgentChatHistoryReducer> mockReducer = new();
mockReducer.Setup(r => r.ReduceAsync(It.IsAny<IReadOnlyList<ChatMessageContent>>(), default)).ReturnsAsync((IEnumerable<ChatMessageContent>?)[]);

ChatHistory history = [.. MockHistoryGenerator.CreateSimpleHistory(10)];

// Act
bool isReduced = await history.ReduceAsync(mockReducer.Object, default);
bool isReduced = await history.TryReduceAsync(mockReducer.Object, default);

// Assert
Assert.True(isReduced);
Expand Down
Loading
Loading