[.Net] refactor over streaming version api (#2461)

* update

* update

* fix comment
This commit is contained in:
Xiaoyun Zhang
2024-05-05 07:51:00 -07:00
committed by GitHub
parent 4711d7bb9c
commit e878be55a3
39 changed files with 268 additions and 407 deletions

View File

@@ -23,7 +23,7 @@ public interface IMiddlewareAgent : IAgent
void Use(IMiddleware middleware);
}
public interface IMiddlewareStreamAgent : IMiddlewareAgent, IStreamingAgent
public interface IMiddlewareStreamAgent : IStreamingAgent
{
/// <summary>
/// Get the inner agent.
@@ -44,7 +44,11 @@ public interface IMiddlewareAgent<out T> : IMiddlewareAgent
T TAgent { get; }
}
public interface IMiddlewareStreamAgent<out T> : IMiddlewareStreamAgent, IMiddlewareAgent<T>
public interface IMiddlewareStreamAgent<out T> : IMiddlewareStreamAgent
where T : IStreamingAgent
{
/// <summary>
/// Get the typed inner agent.
/// </summary>
T TStreamingAgent { get; }
}

View File

@@ -3,7 +3,6 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
@@ -12,7 +11,7 @@ namespace AutoGen.Core;
/// </summary>
public interface IStreamingAgent : IAgent
{
public Task<IAsyncEnumerable<IStreamingMessage>> GenerateStreamingReplyAsync(
public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default);

View File

@@ -14,7 +14,7 @@ namespace AutoGen.Core;
/// </summary>
public class MiddlewareAgent : IMiddlewareAgent
{
private readonly IAgent _agent;
private IAgent _agent;
private readonly List<IMiddleware> middlewares = new();
/// <summary>
@@ -22,10 +22,17 @@ public class MiddlewareAgent : IMiddlewareAgent
/// </summary>
/// <param name="innerAgent">the inner agent where middleware will be added.</param>
/// <param name="name">the name of the agent if provided. Otherwise, the name of <paramref name="innerAgent"/> will be used.</param>
public MiddlewareAgent(IAgent innerAgent, string? name = null)
public MiddlewareAgent(IAgent innerAgent, string? name = null, IEnumerable<IMiddleware>? middlewares = null)
{
this.Name = name ?? innerAgent.Name;
this._agent = innerAgent;
if (middlewares != null && middlewares.Any())
{
foreach (var middleware in middlewares)
{
this.Use(middleware);
}
}
}
/// <summary>
@@ -55,13 +62,7 @@ public class MiddlewareAgent : IMiddlewareAgent
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
IAgent agent = this._agent;
foreach (var middleware in this.middlewares)
{
agent = new DelegateAgent(middleware, agent);
}
return agent.GenerateReplyAsync(messages, options, cancellationToken);
return _agent.GenerateReplyAsync(messages, options, cancellationToken);
}
/// <summary>
@@ -71,15 +72,18 @@ public class MiddlewareAgent : IMiddlewareAgent
/// </summary>
public void Use(Func<IEnumerable<IMessage>, GenerateReplyOptions?, IAgent, CancellationToken, Task<IMessage>> func, string? middlewareName = null)
{
this.middlewares.Add(new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) =>
var middleware = new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) =>
{
return await func(context.Messages, context.Options, agent, cancellationToken);
}));
});
this.Use(middleware);
}
public void Use(IMiddleware middleware)
{
this.middlewares.Add(middleware);
_agent = new DelegateAgent(middleware, _agent);
}
public override string ToString()

View File

@@ -2,33 +2,31 @@
// MiddlewareStreamingAgent.cs
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
public class MiddlewareStreamingAgent : MiddlewareAgent, IMiddlewareStreamAgent
public class MiddlewareStreamingAgent : IMiddlewareStreamAgent
{
private readonly IStreamingAgent _agent;
private IStreamingAgent _agent;
private readonly List<IStreamingMiddleware> _streamingMiddlewares = new();
private readonly List<IMiddleware> _middlewares = new();
public MiddlewareStreamingAgent(
IStreamingAgent agent,
string? name = null,
IEnumerable<IStreamingMiddleware>? streamingMiddlewares = null,
IEnumerable<IMiddleware>? middlewares = null)
: base(agent, name)
IEnumerable<IStreamingMiddleware>? streamingMiddlewares = null)
{
this.Name = name ?? agent.Name;
_agent = agent;
if (streamingMiddlewares != null)
{
_streamingMiddlewares.AddRange(streamingMiddlewares);
}
if (middlewares != null)
if (streamingMiddlewares != null && streamingMiddlewares.Any())
{
_middlewares.AddRange(middlewares);
foreach (var middleware in streamingMiddlewares)
{
this.UseStreaming(middleware);
}
}
}
@@ -42,26 +40,28 @@ public class MiddlewareStreamingAgent : MiddlewareAgent, IMiddlewareStreamAgent
/// </summary>
public IEnumerable<IStreamingMiddleware> StreamingMiddlewares => _streamingMiddlewares;
public Task<IAsyncEnumerable<IStreamingMessage>> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
var agent = _agent;
foreach (var middleware in _streamingMiddlewares)
{
agent = new DelegateStreamingAgent(middleware, agent);
}
public string Name { get; }
return agent.GenerateStreamingReplyAsync(messages, options, cancellationToken);
public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
return _agent.GenerateReplyAsync(messages, options, cancellationToken);
}
public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
return _agent.GenerateStreamingReplyAsync(messages, options, cancellationToken);
}
public void UseStreaming(IStreamingMiddleware middleware)
{
_streamingMiddlewares.Add(middleware);
_agent = new DelegateStreamingAgent(middleware, _agent);
}
private class DelegateStreamingAgent : IStreamingAgent
{
private IStreamingMiddleware? streamingMiddleware;
private IMiddleware? middleware;
private IStreamingAgent innerAgent;
public string Name => innerAgent.Name;
@@ -72,24 +72,19 @@ public class MiddlewareStreamingAgent : MiddlewareAgent, IMiddlewareStreamAgent
this.innerAgent = next;
}
public DelegateStreamingAgent(IMiddleware middleware, IStreamingAgent next)
{
this.middleware = middleware;
this.innerAgent = next;
}
public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
if (middleware is null)
if (this.streamingMiddleware is null)
{
return await innerAgent.GenerateReplyAsync(messages, options, cancellationToken);
return innerAgent.GenerateReplyAsync(messages, options, cancellationToken);
}
var context = new MiddlewareContext(messages, options);
return await middleware.InvokeAsync(context, innerAgent, cancellationToken);
return this.streamingMiddleware.InvokeAsync(context, (IAgent)innerAgent, cancellationToken);
}
public Task<IAsyncEnumerable<IStreamingMessage>> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
if (streamingMiddleware is null)
{
@@ -105,20 +100,20 @@ public class MiddlewareStreamingAgent : MiddlewareAgent, IMiddlewareStreamAgent
public sealed class MiddlewareStreamingAgent<T> : MiddlewareStreamingAgent, IMiddlewareStreamAgent<T>
where T : IStreamingAgent
{
public MiddlewareStreamingAgent(T innerAgent, string? name = null)
: base(innerAgent, name)
public MiddlewareStreamingAgent(T innerAgent, string? name = null, IEnumerable<IStreamingMiddleware>? streamingMiddlewares = null)
: base(innerAgent, name, streamingMiddlewares)
{
TAgent = innerAgent;
TStreamingAgent = innerAgent;
}
public MiddlewareStreamingAgent(MiddlewareStreamingAgent<T> other)
: base(other)
{
TAgent = other.TAgent;
TStreamingAgent = other.TStreamingAgent;
}
/// <summary>
/// Get the inner agent.
/// </summary>
public T TAgent { get; }
public T TStreamingAgent { get; }
}

View File

@@ -20,6 +20,7 @@ public static class MiddlewareExtension
/// <param name="replyFunc"></param>
/// <returns></returns>
/// <exception cref="Exception">throw when agent name is null.</exception>
[Obsolete("Use RegisterMiddleware instead.")]
public static MiddlewareAgent<TAgent> RegisterReply<TAgent>(
this TAgent agent,
Func<IEnumerable<IMessage>, CancellationToken, Task<IMessage?>> replyFunc)
@@ -45,6 +46,7 @@ public static class MiddlewareExtension
/// One example is <see cref="PrintMessageMiddlewareExtension.RegisterPrintMessage{TAgent}(TAgent)" />, which print the formatted message to console before the agent return the reply.
/// </summary>
/// <exception cref="Exception">throw when agent name is null.</exception>
[Obsolete("Use RegisterMiddleware instead.")]
public static MiddlewareAgent<TAgent> RegisterPostProcess<TAgent>(
this TAgent agent,
Func<IEnumerable<IMessage>, IMessage, CancellationToken, Task<IMessage>> postprocessFunc)
@@ -62,6 +64,7 @@ public static class MiddlewareExtension
/// Register a pre process hook to an agent. The hook will be called before the agent generate the reply. This is useful when you want to modify the conversation history before the agent generate the reply.
/// </summary>
/// <exception cref="Exception">throw when agent name is null.</exception>
[Obsolete("Use RegisterMiddleware instead.")]
public static MiddlewareAgent<TAgent> RegisterPreProcess<TAgent>(
this TAgent agent,
Func<IEnumerable<IMessage>, CancellationToken, Task<IEnumerable<IMessage>>> preprocessFunc)
@@ -77,6 +80,7 @@ public static class MiddlewareExtension
/// <summary>
/// Register a middleware to an existing agent and return a new agent with the middleware.
/// To register a streaming middleware, use <see cref="StreamingMiddlewareExtension.RegisterStreamingMiddleware{TAgent}(MiddlewareStreamingAgent{TAgent}, IStreamingMiddleware)"/>.
/// </summary>
public static MiddlewareAgent<TAgent> RegisterMiddleware<TAgent>(
this TAgent agent,
@@ -94,6 +98,7 @@ public static class MiddlewareExtension
/// <summary>
/// Register a middleware to an existing agent and return a new agent with the middleware.
/// To register a streaming middleware, use <see cref="StreamingMiddlewareExtension.RegisterStreamingMiddleware{TAgent}(MiddlewareStreamingAgent{TAgent}, IStreamingMiddleware)"/>.
/// </summary>
public static MiddlewareAgent<TAgent> RegisterMiddleware<TAgent>(
this TAgent agent,
@@ -107,6 +112,7 @@ public static class MiddlewareExtension
/// <summary>
/// Register a middleware to an existing agent and return a new agent with the middleware.
/// To register a streaming middleware, use <see cref="StreamingMiddlewareExtension.RegisterStreamingMiddleware{TAgent}(MiddlewareStreamingAgent{TAgent}, IStreamingMiddleware)"/>.
/// </summary>
public static MiddlewareAgent<TAgent> RegisterMiddleware<TAgent>(
this MiddlewareAgent<TAgent> agent,
@@ -124,6 +130,7 @@ public static class MiddlewareExtension
/// <summary>
/// Register a middleware to an existing agent and return a new agent with the middleware.
/// To register a streaming middleware, use <see cref="StreamingMiddlewareExtension.RegisterStreamingMiddleware{TAgent}(MiddlewareStreamingAgent{TAgent}, IStreamingMiddleware)"/>.
/// </summary>
public static MiddlewareAgent<TAgent> RegisterMiddleware<TAgent>(
this MiddlewareAgent<TAgent> agent,

View File

@@ -62,7 +62,7 @@ public static class PrintMessageMiddlewareExtension
{
var middleware = new PrintMessageMiddleware();
var middlewareAgent = new MiddlewareStreamingAgent<TAgent>(agent);
middlewareAgent.Use(middleware);
middlewareAgent.UseStreaming(middleware);
return middlewareAgent;
}

View File

@@ -1,17 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// StreamingMiddlewareExtension.cs
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
public static class StreamingMiddlewareExtension
{
/// <summary>
/// Register a middleware to an existing agent and return a new agent with the middleware.
/// Register an <see cref="IStreamingMiddleware"/> to an existing <see cref="IStreamingAgent"/> and return a new agent with the registered middleware.
/// For registering an <see cref="IMiddleware"/>, please refer to <see cref="MiddlewareExtension.RegisterMiddleware{TAgent}(MiddlewareAgent{TAgent}, IMiddleware)"/>
/// </summary>
public static MiddlewareStreamingAgent<TStreamingAgent> RegisterStreamingMiddleware<TStreamingAgent>(
this TStreamingAgent agent,
@@ -21,16 +17,12 @@ public static class StreamingMiddlewareExtension
var middlewareAgent = new MiddlewareStreamingAgent<TStreamingAgent>(agent);
middlewareAgent.UseStreaming(middleware);
if (middleware is IMiddleware middlewareBase)
{
middlewareAgent.Use(middlewareBase);
}
return middlewareAgent;
}
/// <summary>
/// Register a middleware to an existing agent and return a new agent with the middleware.
/// Register an <see cref="IStreamingMiddleware"/> to an existing <see cref="IStreamingAgent"/> and return a new agent with the registered middleware.
/// For registering an <see cref="IMiddleware"/>, please refer to <see cref="MiddlewareExtension.RegisterMiddleware{TAgent}(MiddlewareAgent{TAgent}, IMiddleware)"/>
/// </summary>
public static MiddlewareStreamingAgent<TAgent> RegisterStreamingMiddleware<TAgent>(
this MiddlewareStreamingAgent<TAgent> agent,
@@ -40,75 +32,6 @@ public static class StreamingMiddlewareExtension
var copyAgent = new MiddlewareStreamingAgent<TAgent>(agent);
copyAgent.UseStreaming(middleware);
if (middleware is IMiddleware middlewareBase)
{
copyAgent.Use(middlewareBase);
}
return copyAgent;
}
/// <summary>
/// Register a middleware to an existing agent and return a new agent with the middleware.
/// </summary>
public static MiddlewareStreamingAgent<TAgent> RegisterStreamingMiddleware<TAgent>(
this TAgent agent,
Func<MiddlewareContext, IStreamingAgent, CancellationToken, Task<IAsyncEnumerable<IStreamingMessage>>> func,
string? middlewareName = null)
where TAgent : IStreamingAgent
{
var middleware = new DelegateStreamingMiddleware(middlewareName, new DelegateStreamingMiddleware.MiddlewareDelegate(func));
return agent.RegisterStreamingMiddleware(middleware);
}
/// <summary>
/// Register a streaming middleware to an existing agent and return a new agent with the middleware.
/// </summary>
public static MiddlewareStreamingAgent<TAgent> RegisterStreamingMiddleware<TAgent>(
this MiddlewareStreamingAgent<TAgent> agent,
Func<MiddlewareContext, IStreamingAgent, CancellationToken, Task<IAsyncEnumerable<IStreamingMessage>>> func,
string? middlewareName = null)
where TAgent : IStreamingAgent
{
var middleware = new DelegateStreamingMiddleware(middlewareName, new DelegateStreamingMiddleware.MiddlewareDelegate(func));
return agent.RegisterStreamingMiddleware(middleware);
}
/// <summary>
/// Register a middleware to an existing streaming agent and return a new agent with the middleware.
/// </summary>
public static MiddlewareStreamingAgent<TStreamingAgent> RegisterMiddleware<TStreamingAgent>(
this MiddlewareStreamingAgent<TStreamingAgent> streamingAgent,
Func<IEnumerable<IMessage>, GenerateReplyOptions?, IAgent, CancellationToken, Task<IMessage>> func,
string? middlewareName = null)
where TStreamingAgent : IStreamingAgent
{
var middleware = new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) =>
{
return await func(context.Messages, context.Options, agent, cancellationToken);
});
return streamingAgent.RegisterMiddleware(middleware);
}
/// <summary>
/// Register a middleware to an existing streaming agent and return a new agent with the middleware.
/// </summary>
public static MiddlewareStreamingAgent<TStreamingAgent> RegisterMiddleware<TStreamingAgent>(
this MiddlewareStreamingAgent<TStreamingAgent> streamingAgent,
IMiddleware middleware)
where TStreamingAgent : IStreamingAgent
{
var copyAgent = new MiddlewareStreamingAgent<TStreamingAgent>(streamingAgent);
copyAgent.Use(middleware);
if (middleware is IStreamingMiddleware streamingMiddleware)
{
copyAgent.UseStreaming(streamingMiddleware);
}
return copyAgent;
}
}

View File

@@ -1,38 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// DelegateStreamingMiddleware.cs
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
internal class DelegateStreamingMiddleware : IStreamingMiddleware
{
public delegate Task<IAsyncEnumerable<IStreamingMessage>> MiddlewareDelegate(
MiddlewareContext context,
IStreamingAgent agent,
CancellationToken cancellationToken);
private readonly MiddlewareDelegate middlewareDelegate;
public DelegateStreamingMiddleware(string? name, MiddlewareDelegate middlewareDelegate)
{
this.Name = name;
this.middlewareDelegate = middlewareDelegate;
}
public string? Name { get; }
public Task<IAsyncEnumerable<IStreamingMessage>> InvokeAsync(
MiddlewareContext context,
IStreamingAgent agent,
CancellationToken cancellationToken = default)
{
var messages = context.Messages;
var options = context.Options;
return this.middlewareDelegate(context, agent, cancellationToken);
}
}

View File

@@ -29,7 +29,7 @@ namespace AutoGen.Core;
/// If the streaming reply from the inner agent is other types of message, the most recent message will be used to invoke the function.
/// </para>
/// </summary>
public class FunctionCallMiddleware : IMiddleware, IStreamingMiddleware
public class FunctionCallMiddleware : IStreamingMiddleware
{
private readonly IEnumerable<FunctionContract>? functions;
private readonly IDictionary<string, Func<string, Task<string>>>? functionMap;
@@ -71,15 +71,10 @@ public class FunctionCallMiddleware : IMiddleware, IStreamingMiddleware
return reply;
}
public Task<IAsyncEnumerable<IStreamingMessage>> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default)
{
return Task.FromResult(this.StreamingInvokeAsync(context, agent, cancellationToken));
}
private async IAsyncEnumerable<IStreamingMessage> StreamingInvokeAsync(
public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(
MiddlewareContext context,
IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken)
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var lastMessage = context.Messages.Last();
if (lastMessage is ToolCallMessage toolCallMessage)
@@ -93,7 +88,7 @@ public class FunctionCallMiddleware : IMiddleware, IStreamingMiddleware
options.Functions = combinedFunctions?.ToArray();
IStreamingMessage? initMessage = default;
await foreach (var message in await agent.GenerateStreamingReplyAsync(context.Messages, options, cancellationToken))
await foreach (var message in agent.GenerateStreamingReplyAsync(context.Messages, options, cancellationToken))
{
if (message is ToolCallMessageUpdate toolCallMessageUpdate && this.functionMap != null)
{

View File

@@ -7,7 +7,7 @@ using System.Threading.Tasks;
namespace AutoGen.Core;
/// <summary>
/// The middleware interface
/// The middleware interface. For streaming-version middleware, check <see cref="IStreamingMiddleware"/>.
/// </summary>
public interface IMiddleware
{

View File

@@ -3,18 +3,18 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
/// <summary>
/// The streaming middleware interface
/// The streaming middleware interface. For non-streaming version middleware, check <see cref="IMiddleware"/>.
/// </summary>
public interface IStreamingMiddleware
public interface IStreamingMiddleware : IMiddleware
{
public string? Name { get; }
public Task<IAsyncEnumerable<IStreamingMessage>> InvokeAsync(
/// <summary>
/// The streaming version of <see cref="IMiddleware.InvokeAsync(MiddlewareContext, IAgent, CancellationToken)"/>.
/// </summary>
public IAsyncEnumerable<IStreamingMessage> InvokeAsync(
MiddlewareContext context,
IStreamingAgent agent,
CancellationToken cancellationToken = default);

View File

@@ -2,6 +2,8 @@
// PrintMessageMiddleware.cs
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
@@ -10,7 +12,7 @@ namespace AutoGen.Core;
/// <summary>
/// The middleware that prints the reply from agent to the console.
/// </summary>
public class PrintMessageMiddleware : IMiddleware
public class PrintMessageMiddleware : IStreamingMiddleware
{
public string? Name => nameof(PrintMessageMiddleware);
@@ -19,51 +21,12 @@ public class PrintMessageMiddleware : IMiddleware
if (agent is IStreamingAgent streamingAgent)
{
IMessage? recentUpdate = null;
await foreach (var message in await streamingAgent.GenerateStreamingReplyAsync(context.Messages, context.Options, cancellationToken))
await foreach (var message in this.InvokeAsync(context, streamingAgent, cancellationToken))
{
if (message is TextMessageUpdate textMessageUpdate)
{
if (recentUpdate is null)
{
// Print from: xxx
Console.WriteLine($"from: {textMessageUpdate.From}");
recentUpdate = new TextMessage(textMessageUpdate);
Console.Write(textMessageUpdate.Content);
}
else if (recentUpdate is TextMessage recentTextMessage)
{
// Print the content of the message
Console.Write(textMessageUpdate.Content);
recentTextMessage.Update(textMessageUpdate);
}
else
{
throw new InvalidOperationException("The recent update is not a TextMessage");
}
}
else if (message is ToolCallMessageUpdate toolCallUpdate)
{
if (recentUpdate is null)
{
recentUpdate = new ToolCallMessage(toolCallUpdate);
}
else if (recentUpdate is ToolCallMessage recentToolCallMessage)
{
recentToolCallMessage.Update(toolCallUpdate);
}
else
{
throw new InvalidOperationException("The recent update is not a ToolCallMessage");
}
}
else if (message is IMessage imessage)
if (message is IMessage imessage)
{
recentUpdate = imessage;
}
else
{
throw new InvalidOperationException("The message is not a valid message");
}
}
Console.WriteLine();
if (recentUpdate is not null && recentUpdate is not TextMessage)
@@ -84,4 +47,72 @@ public class PrintMessageMiddleware : IMiddleware
return reply;
}
}
public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
IMessage? recentUpdate = null;
await foreach (var message in agent.GenerateStreamingReplyAsync(context.Messages, context.Options, cancellationToken))
{
if (message is TextMessageUpdate textMessageUpdate)
{
if (recentUpdate is null)
{
// Print from: xxx
Console.WriteLine($"from: {textMessageUpdate.From}");
recentUpdate = new TextMessage(textMessageUpdate);
Console.Write(textMessageUpdate.Content);
yield return message;
}
else if (recentUpdate is TextMessage recentTextMessage)
{
// Print the content of the message
Console.Write(textMessageUpdate.Content);
recentTextMessage.Update(textMessageUpdate);
yield return recentTextMessage;
}
else
{
throw new InvalidOperationException("The recent update is not a TextMessage");
}
}
else if (message is ToolCallMessageUpdate toolCallUpdate)
{
if (recentUpdate is null)
{
recentUpdate = new ToolCallMessage(toolCallUpdate);
yield return message;
}
else if (recentUpdate is ToolCallMessage recentToolCallMessage)
{
recentToolCallMessage.Update(toolCallUpdate);
yield return message;
}
else
{
throw new InvalidOperationException("The recent update is not a ToolCallMessage");
}
}
else if (message is IMessage imessage)
{
recentUpdate = imessage;
yield return imessage;
}
else
{
throw new InvalidOperationException("The message is not a valid message");
}
}
Console.WriteLine();
if (recentUpdate is not null && recentUpdate is not TextMessage)
{
Console.WriteLine(recentUpdate.FormatMessage());
}
yield return recentUpdate ?? throw new InvalidOperationException("The message is not a valid message");
}
}

View File

@@ -28,19 +28,19 @@ public static class AgentExtension
string codeBlockSuffix = "```",
int maximumOutputToKeep = 500)
{
return agent.RegisterReply(async (msgs, ct) =>
return agent.RegisterMiddleware(async (msgs, option, innerAgent, ct) =>
{
var lastMessage = msgs.LastOrDefault();
if (lastMessage == null || lastMessage.GetContent() is null)
{
return null;
return await innerAgent.GenerateReplyAsync(msgs, option, ct);
}
// retrieve all code blocks from last message
var codeBlocks = lastMessage.GetContent()!.Split(new[] { codeBlockPrefix }, StringSplitOptions.RemoveEmptyEntries);
if (codeBlocks.Length <= 0)
{
return null;
return await innerAgent.GenerateReplyAsync(msgs, option, ct);
}
// run code blocks

View File

@@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.Core;
@@ -77,19 +78,14 @@ public class MistralClientAgent : IStreamingAgent
return new MessageEnvelope<ChatCompletionResponse>(response, from: this.Name);
}
public async Task<IAsyncEnumerable<IStreamingMessage>> GenerateStreamingReplyAsync(
public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var request = BuildChatRequest(messages, options);
var response = _client.StreamingChatCompletionsAsync(request);
return ProcessMessage(response);
}
private async IAsyncEnumerable<IMessage> ProcessMessage(IAsyncEnumerable<ChatCompletionResponse> response)
{
await foreach (var content in response)
{
yield return new MessageEnvelope<ChatCompletionResponse>(content, from: this.Name);

View File

@@ -18,9 +18,7 @@ public static class MistralAgentExtension
connector = new MistralChatMessageConnector();
}
return agent.RegisterStreamingMiddleware(connector)
.RegisterMiddleware(connector);
return agent.RegisterStreamingMiddleware(connector);
}
/// <summary>
@@ -34,7 +32,6 @@ public static class MistralAgentExtension
connector = new MistralChatMessageConnector();
}
return agent.RegisterStreamingMiddleware(connector)
.RegisterMiddleware(connector);
return agent.RegisterStreamingMiddleware(connector);
}
}

View File

@@ -15,17 +15,12 @@ public class MistralChatMessageConnector : IStreamingMiddleware, IMiddleware
{
public string? Name => nameof(MistralChatMessageConnector);
public Task<IAsyncEnumerable<IStreamingMessage>> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default)
{
return Task.FromResult(StreamingInvoke(context, agent, cancellationToken));
}
private async IAsyncEnumerable<IStreamingMessage> StreamingInvoke(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = context.Messages;
var chatMessages = ProcessMessage(messages, agent);
var chunks = new List<ChatCompletionResponse>();
await foreach (var reply in await agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken))
await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken))
{
if (reply is IStreamingMessage<ChatCompletionResponse> chatMessage)
{

View File

@@ -90,30 +90,28 @@ public class GPTAgent : IStreamingAgent
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
var agent = this._innerAgent
.RegisterMessageConnector();
if (this.functionMap is not null)
{
var functionMapMiddleware = new FunctionCallMiddleware(functionMap: this.functionMap);
agent = agent.RegisterMiddleware(functionMapMiddleware);
}
return await agent.GenerateReplyAsync(messages, options, cancellationToken);
}
public async Task<IAsyncEnumerable<IStreamingMessage>> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
var agent = this._innerAgent
.RegisterMessageConnector();
var agent = this._innerAgent.RegisterMessageConnector();
if (this.functionMap is not null)
{
var functionMapMiddleware = new FunctionCallMiddleware(functionMap: this.functionMap);
agent = agent.RegisterStreamingMiddleware(functionMapMiddleware);
}
return await agent.GenerateStreamingReplyAsync(messages, options, cancellationToken);
return await agent.GenerateReplyAsync(messages, options, cancellationToken);
}
public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
var agent = this._innerAgent.RegisterMessageConnector();
if (this.functionMap is not null)
{
var functionMapMiddleware = new FunctionCallMiddleware(functionMap: this.functionMap);
agent = agent.RegisterStreamingMiddleware(functionMapMiddleware);
}
return agent.GenerateStreamingReplyAsync(messages, options, cancellationToken);
}
}

View File

@@ -87,15 +87,7 @@ public class OpenAIChatAgent : IStreamingAgent
return new MessageEnvelope<ChatCompletions>(reply, from: this.Name);
}
public Task<IAsyncEnumerable<IStreamingMessage>> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
return Task.FromResult(this.StreamingReplyAsync(messages, options, cancellationToken));
}
private async IAsyncEnumerable<IStreamingMessage> StreamingReplyAsync(
public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)

View File

@@ -44,22 +44,14 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa
return PostProcessMessage(reply);
}
public async Task<IAsyncEnumerable<IStreamingMessage>> InvokeAsync(
public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(
MiddlewareContext context,
IStreamingAgent agent,
CancellationToken cancellationToken = default)
{
return InvokeStreamingAsync(context, agent, cancellationToken);
}
private async IAsyncEnumerable<IStreamingMessage> InvokeStreamingAsync(
MiddlewareContext context,
IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken)
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var chatMessages = ProcessIncomingMessages(agent, context.Messages)
.Select(m => new MessageEnvelope<ChatRequestMessage>(m));
var streamingReply = await agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken);
var streamingReply = agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken);
string? currentToolName = null;
await foreach (var reply in streamingReply)
{
@@ -135,6 +127,12 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa
private IMessage PostProcessMessage(IMessage<ChatCompletions> message)
{
// throw exception if prompt filter results is not null
if (message.Content.Choices[0].FinishReason == CompletionsFinishReason.ContentFiltered)
{
throw new InvalidOperationException("The content is filtered because its potential risk. Please try another input.");
}
return PostProcessMessage(message.Content.Choices[0].Message, message.From);
}

View File

@@ -47,20 +47,12 @@ public class SemanticKernelChatMessageContentConnector : IMiddleware, IStreaming
return PostProcessMessage(reply);
}
public Task<IAsyncEnumerable<IStreamingMessage>> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default)
{
return Task.FromResult(InvokeStreamingAsync(context, agent, cancellationToken));
}
private async IAsyncEnumerable<IStreamingMessage> InvokeStreamingAsync(
MiddlewareContext context,
IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken)
public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var chatMessageContents = ProcessMessage(context.Messages, agent)
.Select(m => new MessageEnvelope<ChatMessageContent>(m));
await foreach (var reply in await agent.GenerateStreamingReplyAsync(chatMessageContents, context.Options, cancellationToken))
await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessageContents, context.Options, cancellationToken))
{
yield return PostProcessStreamingMessage(reply);
}

View File

@@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;
@@ -64,17 +65,25 @@ public class SemanticKernelAgent : IStreamingAgent
return new MessageEnvelope<ChatMessageContent>(reply.First(), from: this.Name);
}
public async Task<IAsyncEnumerable<IStreamingMessage>> GenerateStreamingReplyAsync(
public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var chatHistory = BuildChatHistory(messages);
var option = BuildOption(options);
var chatService = _kernel.GetRequiredService<IChatCompletionService>();
var response = chatService.GetStreamingChatMessageContentsAsync(chatHistory, option, _kernel, cancellationToken);
return ProcessMessage(response);
await foreach (var content in response)
{
if (content.ChoiceIndex > 0)
{
throw new InvalidOperationException("Only one choice is supported in streaming response");
}
yield return new MessageEnvelope<StreamingChatMessageContent>(content, from: this.Name);
}
}
private ChatHistory BuildChatHistory(IEnumerable<IMessage> messages)
@@ -101,19 +110,6 @@ public class SemanticKernelAgent : IStreamingAgent
};
}
private async IAsyncEnumerable<IMessage> ProcessMessage(IAsyncEnumerable<StreamingChatMessageContent> response)
{
await foreach (var content in response)
{
if (content.ChoiceIndex > 0)
{
throw new InvalidOperationException("Only one choice is supported in streaming response");
}
yield return new MessageEnvelope<StreamingChatMessageContent>(content, from: this.Name);
}
}
private IEnumerable<ChatMessageContent> ProcessMessage(IEnumerable<IMessage> messages)
{
return messages.Select(m => m switch

View File

@@ -79,19 +79,33 @@ public class ConversableAgent : IAgent
IAgent? agent = null;
foreach (var llmConfig in config.ConfigList ?? Enumerable.Empty<ILLMConfig>())
{
agent = agent switch
var nextAgent = llmConfig switch
{
null => llmConfig switch
{
AzureOpenAIConfig azureConfig => new GPTAgent(this.Name!, this.systemMessage, azureConfig, temperature: config.Temperature ?? 0),
OpenAIConfig openAIConfig => new GPTAgent(this.Name!, this.systemMessage, openAIConfig, temperature: config.Temperature ?? 0),
_ => throw new ArgumentException($"Unsupported config type {llmConfig.GetType()}"),
},
IAgent innerAgent => innerAgent.RegisterReply(async (messages, cancellationToken) =>
{
return await innerAgent.GenerateReplyAsync(messages, cancellationToken: cancellationToken);
}),
AzureOpenAIConfig azureConfig => new GPTAgent(this.Name!, this.systemMessage, azureConfig, temperature: config.Temperature ?? 0),
OpenAIConfig openAIConfig => new GPTAgent(this.Name!, this.systemMessage, openAIConfig, temperature: config.Temperature ?? 0),
_ => throw new ArgumentException($"Unsupported config type {llmConfig.GetType()}"),
};
if (agent == null)
{
agent = nextAgent;
}
else
{
agent = agent.RegisterMiddleware(async (messages, option, agent, cancellationToken) =>
{
var agentResponse = await nextAgent.GenerateReplyAsync(messages, option, cancellationToken: cancellationToken);
if (agentResponse is null)
{
return await agent.GenerateReplyAsync(messages, option, cancellationToken);
}
else
{
return agentResponse;
}
});
}
}
return agent;

View File

@@ -20,7 +20,6 @@
<ProjectReference Include="..\AutoGen.Mistral\AutoGen.Mistral.csproj" />
<ProjectReference Include="..\AutoGen.SemanticKernel\AutoGen.SemanticKernel.csproj" />
<ProjectReference Include="..\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
</ItemGroup>
<ItemGroup>