From e878be55a3663b7864bc0ef8b9526e2f0be2f88f Mon Sep 17 00:00:00 2001 From: Xiaoyun Zhang Date: Sun, 5 May 2024 07:51:00 -0700 Subject: [PATCH] [.Net] refactor over streaming version api (#2461) * update * update * fix comment --- .../CodeSnippet/AgentCodeSnippet.cs | 2 +- .../CodeSnippet/BuildInMessageCodeSnippet.cs | 4 +- .../CodeSnippet/MistralAICodeSnippet.cs | 4 +- .../CodeSnippet/OpenAICodeSnippet.cs | 4 +- .../CodeSnippet/SemanticKernelCodeSnippet.cs | 2 +- .../Example02_TwoAgent_MathChat.cs | 7 +- ...7_Dynamic_GroupChat_Calculate_Fibonacci.cs | 21 +--- .../Example10_SemanticKernel.cs | 2 +- .../Example13_OpenAIAgent_JsonMode.cs | 3 +- dotnet/sample/AutoGen.BasicSamples/Program.cs | 3 +- .../AutoGen.Core/Agent/IMiddlewareAgent.cs | 8 +- .../src/AutoGen.Core/Agent/IStreamingAgent.cs | 3 +- .../src/AutoGen.Core/Agent/MiddlewareAgent.cs | 26 ++-- .../Agent/MiddlewareStreamingAgent.cs | 67 +++++----- .../Extension/MiddlewareExtension.cs | 7 ++ .../PrintMessageMiddlewareExtension.cs | 2 +- .../Extension/StreamingMiddlewareExtension.cs | 85 +------------ .../Middleware/DelegateStreamingMiddleware.cs | 38 ------ .../Middleware/FunctionCallMiddleware.cs | 13 +- .../AutoGen.Core/Middleware/IMiddleware.cs | 2 +- .../Middleware/IStreamingMiddleware.cs | 12 +- .../Middleware/PrintMessageMiddleware.cs | 115 +++++++++++------- .../Extension/AgentExtension.cs | 6 +- .../Agent/MistralClientAgent.cs | 10 +- .../Extension/MistralAgentExtension.cs | 7 +- .../Middleware/MistralChatMessageConnector.cs | 9 +- dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs | 36 +++--- .../AutoGen.OpenAI/Agent/OpenAIChatAgent.cs | 10 +- .../OpenAIChatRequestMessageConnector.cs | 20 ++- ...manticKernelChatMessageContentConnector.cs | 12 +- .../SemanticKernelAgent.cs | 28 ++--- dotnet/src/AutoGen/Agent/ConversableAgent.cs | 36 ++++-- dotnet/src/AutoGen/AutoGen.csproj | 1 - .../MistralClientAgentTests.cs | 12 +- .../test/AutoGen.Tests/OpenAIChatAgentTest.cs | 13 +- .../AutoGen.Tests/RegisterReplyAgentTest.cs | 27 ---- .../AutoGen.Tests/SemanticKernelAgentTest.cs | 9 +- dotnet/test/AutoGen.Tests/SingleAgentTest.cs | 4 +- dotnet/website/update.md | 5 +- 39 files changed, 268 insertions(+), 407 deletions(-) delete mode 100644 dotnet/src/AutoGen.Core/Middleware/DelegateStreamingMiddleware.cs delete mode 100644 dotnet/test/AutoGen.Tests/RegisterReplyAgentTest.cs diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/AgentCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/AgentCodeSnippet.cs index df45e4bfe..abaf94cbd 100644 --- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/AgentCodeSnippet.cs +++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/AgentCodeSnippet.cs @@ -19,7 +19,7 @@ internal class AgentCodeSnippet #region ChatWithAnAgent_GenerateStreamingReplyAsync var textMessage = new TextMessage(Role.User, "Hello"); - await foreach (var streamingReply in await agent.GenerateStreamingReplyAsync([message])) + await foreach (var streamingReply in agent.GenerateStreamingReplyAsync([message])) { if (streamingReply is TextMessageUpdate update) { diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/BuildInMessageCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/BuildInMessageCodeSnippet.cs index b272ba23a..f26485116 100644 --- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/BuildInMessageCodeSnippet.cs +++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/BuildInMessageCodeSnippet.cs @@ -11,7 +11,7 @@ internal class BuildInMessageCodeSnippet IStreamingAgent agent = default; #region StreamingCallCodeSnippet var helloTextMessage = new TextMessage(Role.User, "Hello"); - var reply = await agent.GenerateStreamingReplyAsync([helloTextMessage]); + var reply = agent.GenerateStreamingReplyAsync([helloTextMessage]); var finalTextMessage = new TextMessage(Role.Assistant, string.Empty, from: agent.Name); await foreach (var message in reply) { @@ -24,7 +24,7 @@ internal class BuildInMessageCodeSnippet #endregion StreamingCallCodeSnippet #region StreamingCallWithFinalMessage - reply = await agent.GenerateStreamingReplyAsync([helloTextMessage]); + reply = agent.GenerateStreamingReplyAsync([helloTextMessage]); TextMessage finalMessage = null; await foreach (var message in reply) { diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MistralAICodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MistralAICodeSnippet.cs index 6bb9e9107..cd49810dc 100644 --- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MistralAICodeSnippet.cs +++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MistralAICodeSnippet.cs @@ -38,7 +38,7 @@ internal class MistralAICodeSnippet #endregion create_mistral_agent #region streaming_chat - var reply = await agent.GenerateStreamingReplyAsync( + var reply = agent.GenerateStreamingReplyAsync( messages: [new TextMessage(Role.User, "Hello, how are you?")] ); @@ -75,7 +75,7 @@ internal class MistralAICodeSnippet #endregion create_get_weather_function_call_middleware #region register_function_call_middleware - agent = agent.RegisterMiddleware(functionCallMiddleware); + agent = agent.RegisterStreamingMiddleware(functionCallMiddleware); #endregion register_function_call_middleware #region send_message_with_function_call diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/OpenAICodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/OpenAICodeSnippet.cs index 8d129e751..022f7e9f9 100644 --- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/OpenAICodeSnippet.cs +++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/OpenAICodeSnippet.cs @@ -60,7 +60,7 @@ public partial class OpenAICodeSnippet #endregion create_openai_chat_agent #region create_openai_chat_agent_streaming - var streamingReply = await openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent }); + var streamingReply = openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent }); await foreach (var streamingMessage in streamingReply) { @@ -123,7 +123,7 @@ public partial class OpenAICodeSnippet { functions.GetWeatherFunctionContract.Name, functions.GetWeatherWrapper } // GetWeatherWrapper is a wrapper function for GetWeather, which is also auto-generated }); - openAIChatAgent = openAIChatAgent.RegisterMiddleware(functionCallMiddleware); + openAIChatAgent = openAIChatAgent.RegisterStreamingMiddleware(functionCallMiddleware); #endregion create_function_call_middleware #region chat_agent_send_function_call diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/SemanticKernelCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/SemanticKernelCodeSnippet.cs index 77f93fdf4..b0366eb2b 100644 --- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/SemanticKernelCodeSnippet.cs +++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/SemanticKernelCodeSnippet.cs @@ -49,7 +49,7 @@ public class SemanticKernelCodeSnippet #endregion create_semantic_kernel_agent #region create_semantic_kernel_agent_streaming - var streamingReply = await semanticKernelAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent }); + var streamingReply = semanticKernelAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent }); await foreach (var streamingMessage in streamingReply) { diff --git a/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs b/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs index 8d42b9d05..f20b0848a 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs @@ -18,16 +18,17 @@ public static class Example02_TwoAgent_MathChat var teacher = new AssistantAgent( name: "teacher", systemMessage: @"You are a teacher that create pre-school math question for student and check answer. - If the answer is correct, you terminate conversation by saying [TERMINATE]. + If the answer is correct, you stop the conversation by saying [COMPLETE]. If the answer is wrong, you ask student to fix it.", llmConfig: new ConversableAgentConfig { Temperature = 0, ConfigList = [gpt35], }) - .RegisterPostProcess(async (_, reply, _) => + .RegisterMiddleware(async (msgs, option, agent, _) => { - if (reply.GetContent()?.ToLower().Contains("terminate") is true) + var reply = await agent.GenerateReplyAsync(msgs, option); + if (reply.GetContent()?.ToLower().Contains("complete") is true) { return new TextMessage(Role.Assistant, GroupChatExtension.TERMINATE, from: reply.From); } diff --git a/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs index 6b1dc0965..89e6f45f8 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs @@ -85,26 +85,16 @@ public partial class Example07_Dynamic_GroupChat_Calculate_Fibonacci systemMessage: "You run dotnet code", defaultReply: "No code available.") .RegisterDotnetCodeBlockExectionHook(interactiveService: service) - .RegisterReply(async (msgs, _) => + .RegisterMiddleware(async (msgs, option, agent, _) => { - if (msgs.Count() == 0) + if (msgs.Count() == 0 || msgs.All(msg => msg.From != "coder")) { return new TextMessage(Role.Assistant, "No code available. Coder please write code"); } - - return null; - }) - .RegisterPreProcess(async (msgs, _) => - { - // retrieve the most recent message from coder - var coderMsg = msgs.LastOrDefault(msg => msg.From == "coder"); - if (coderMsg is null) - { - return Enumerable.Empty(); - } else { - return new[] { coderMsg }; + var coderMsg = msgs.Last(msg => msg.From == "coder"); + return await agent.GenerateReplyAsync([coderMsg], option); } }) .RegisterPrintMessage(); @@ -122,8 +112,9 @@ public partial class Example07_Dynamic_GroupChat_Calculate_Fibonacci systemMessage: "You are group admin, terminate the group chat once task is completed by saying [TERMINATE] plus the final answer", temperature: 0, config: gpt3Config) - .RegisterPostProcess(async (_, reply, _) => + .RegisterMiddleware(async (msgs, option, agent, _) => { + var reply = await agent.GenerateReplyAsync(msgs, option); if (reply is TextMessage textMessage && textMessage.Content.Contains("TERMINATE") is true) { var content = $"{textMessage.Content}\n\n {GroupChatExtension.TERMINATE}"; diff --git a/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs b/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs index e4ef7de9d..61c341204 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs @@ -62,7 +62,7 @@ public class Example10_SemanticKernel Console.WriteLine((reply as IMessage).Content.Items[0].As().Text); var skAgentWithMiddleware = skAgent - .RegisterMessageConnector() + .RegisterMessageConnector() // Register the message connector to support more AutoGen built-in message types .RegisterPrintMessage(); // Now the skAgentWithMiddleware supports more IMessage types like TextMessage, ImageMessage or MultiModalMessage diff --git a/dotnet/sample/AutoGen.BasicSamples/Example13_OpenAIAgent_JsonMode.cs b/dotnet/sample/AutoGen.BasicSamples/Example13_OpenAIAgent_JsonMode.cs index 2591ab230..35b7b7d1d 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example13_OpenAIAgent_JsonMode.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example13_OpenAIAgent_JsonMode.cs @@ -28,7 +28,8 @@ public class Example13_OpenAIAgent_JsonMode systemMessage: "You are a helpful assistant designed to output JSON.", seed: 0, // explicitly set a seed to enable deterministic output responseFormat: ChatCompletionsResponseFormat.JsonObject) // set response format to JSON object to enable JSON mode - .RegisterMessageConnector(); + .RegisterMessageConnector() + .RegisterPrintMessage(); #endregion create_agent #region chat_with_agent diff --git a/dotnet/sample/AutoGen.BasicSamples/Program.cs b/dotnet/sample/AutoGen.BasicSamples/Program.cs index bddbb68bf..fb0bacbb5 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Program.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Program.cs @@ -1,5 +1,4 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Program.cs -using AutoGen.BasicSample; -await Example14_MistralClientAgent_TokenCount.RunAsync(); +await Example02_TwoAgent_MathChat.RunAsync(); diff --git a/dotnet/src/AutoGen.Core/Agent/IMiddlewareAgent.cs b/dotnet/src/AutoGen.Core/Agent/IMiddlewareAgent.cs index 7b318183d..a0b01e7c3 100644 --- a/dotnet/src/AutoGen.Core/Agent/IMiddlewareAgent.cs +++ b/dotnet/src/AutoGen.Core/Agent/IMiddlewareAgent.cs @@ -23,7 +23,7 @@ public interface IMiddlewareAgent : IAgent void Use(IMiddleware middleware); } -public interface IMiddlewareStreamAgent : IMiddlewareAgent, IStreamingAgent +public interface IMiddlewareStreamAgent : IStreamingAgent { /// /// Get the inner agent. @@ -44,7 +44,11 @@ public interface IMiddlewareAgent : IMiddlewareAgent T TAgent { get; } } -public interface IMiddlewareStreamAgent : IMiddlewareStreamAgent, IMiddlewareAgent +public interface IMiddlewareStreamAgent : IMiddlewareStreamAgent where T : IStreamingAgent { + /// + /// Get the typed inner agent. + /// + T TStreamingAgent { get; } } diff --git a/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs b/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs index f4004b139..665f18bac 100644 --- a/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs +++ b/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs @@ -3,7 +3,6 @@ using System.Collections.Generic; using System.Threading; -using System.Threading.Tasks; namespace AutoGen.Core; @@ -12,7 +11,7 @@ namespace AutoGen.Core; /// public interface IStreamingAgent : IAgent { - public Task> GenerateStreamingReplyAsync( + public IAsyncEnumerable GenerateStreamingReplyAsync( IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default); diff --git a/dotnet/src/AutoGen.Core/Agent/MiddlewareAgent.cs b/dotnet/src/AutoGen.Core/Agent/MiddlewareAgent.cs index 307e0da79..84d0d4b59 100644 --- a/dotnet/src/AutoGen.Core/Agent/MiddlewareAgent.cs +++ b/dotnet/src/AutoGen.Core/Agent/MiddlewareAgent.cs @@ -14,7 +14,7 @@ namespace AutoGen.Core; /// public class MiddlewareAgent : IMiddlewareAgent { - private readonly IAgent _agent; + private IAgent _agent; private readonly List middlewares = new(); /// @@ -22,10 +22,17 @@ public class MiddlewareAgent : IMiddlewareAgent /// /// the inner agent where middleware will be added. /// the name of the agent if provided. Otherwise, the name of will be used. - public MiddlewareAgent(IAgent innerAgent, string? name = null) + public MiddlewareAgent(IAgent innerAgent, string? name = null, IEnumerable? middlewares = null) { this.Name = name ?? innerAgent.Name; this._agent = innerAgent; + if (middlewares != null && middlewares.Any()) + { + foreach (var middleware in middlewares) + { + this.Use(middleware); + } + } } /// @@ -55,13 +62,7 @@ public class MiddlewareAgent : IMiddlewareAgent GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { - IAgent agent = this._agent; - foreach (var middleware in this.middlewares) - { - agent = new DelegateAgent(middleware, agent); - } - - return agent.GenerateReplyAsync(messages, options, cancellationToken); + return _agent.GenerateReplyAsync(messages, options, cancellationToken); } /// @@ -71,15 +72,18 @@ public class MiddlewareAgent : IMiddlewareAgent /// public void Use(Func, GenerateReplyOptions?, IAgent, CancellationToken, Task> func, string? middlewareName = null) { - this.middlewares.Add(new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) => + var middleware = new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) => { return await func(context.Messages, context.Options, agent, cancellationToken); - })); + }); + + this.Use(middleware); } public void Use(IMiddleware middleware) { this.middlewares.Add(middleware); + _agent = new DelegateAgent(middleware, _agent); } public override string ToString() diff --git a/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs b/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs index b83922227..251d3c110 100644 --- a/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs +++ b/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs @@ -2,33 +2,31 @@ // MiddlewareStreamingAgent.cs using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; namespace AutoGen.Core; -public class MiddlewareStreamingAgent : MiddlewareAgent, IMiddlewareStreamAgent +public class MiddlewareStreamingAgent : IMiddlewareStreamAgent { - private readonly IStreamingAgent _agent; + private IStreamingAgent _agent; private readonly List _streamingMiddlewares = new(); - private readonly List _middlewares = new(); public MiddlewareStreamingAgent( IStreamingAgent agent, string? name = null, - IEnumerable? streamingMiddlewares = null, - IEnumerable? middlewares = null) - : base(agent, name) + IEnumerable? streamingMiddlewares = null) { + this.Name = name ?? agent.Name; _agent = agent; - if (streamingMiddlewares != null) - { - _streamingMiddlewares.AddRange(streamingMiddlewares); - } - if (middlewares != null) + if (streamingMiddlewares != null && streamingMiddlewares.Any()) { - _middlewares.AddRange(middlewares); + foreach (var middleware in streamingMiddlewares) + { + this.UseStreaming(middleware); + } } } @@ -42,26 +40,28 @@ public class MiddlewareStreamingAgent : MiddlewareAgent, IMiddlewareStreamAgent /// public IEnumerable StreamingMiddlewares => _streamingMiddlewares; - public Task> GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) - { - var agent = _agent; - foreach (var middleware in _streamingMiddlewares) - { - agent = new DelegateStreamingAgent(middleware, agent); - } + public string Name { get; } - return agent.GenerateStreamingReplyAsync(messages, options, cancellationToken); + public Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) + { + return _agent.GenerateReplyAsync(messages, options, cancellationToken); + } + + public IAsyncEnumerable GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) + { + + return _agent.GenerateStreamingReplyAsync(messages, options, cancellationToken); } public void UseStreaming(IStreamingMiddleware middleware) { _streamingMiddlewares.Add(middleware); + _agent = new DelegateStreamingAgent(middleware, _agent); } private class DelegateStreamingAgent : IStreamingAgent { private IStreamingMiddleware? streamingMiddleware; - private IMiddleware? middleware; private IStreamingAgent innerAgent; public string Name => innerAgent.Name; @@ -72,24 +72,19 @@ public class MiddlewareStreamingAgent : MiddlewareAgent, IMiddlewareStreamAgent this.innerAgent = next; } - public DelegateStreamingAgent(IMiddleware middleware, IStreamingAgent next) - { - this.middleware = middleware; - this.innerAgent = next; - } - public async Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) + public Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { - if (middleware is null) + if (this.streamingMiddleware is null) { - return await innerAgent.GenerateReplyAsync(messages, options, cancellationToken); + return innerAgent.GenerateReplyAsync(messages, options, cancellationToken); } var context = new MiddlewareContext(messages, options); - return await middleware.InvokeAsync(context, innerAgent, cancellationToken); + return this.streamingMiddleware.InvokeAsync(context, (IAgent)innerAgent, cancellationToken); } - public Task> GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) + public IAsyncEnumerable GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { if (streamingMiddleware is null) { @@ -105,20 +100,20 @@ public class MiddlewareStreamingAgent : MiddlewareAgent, IMiddlewareStreamAgent public sealed class MiddlewareStreamingAgent : MiddlewareStreamingAgent, IMiddlewareStreamAgent where T : IStreamingAgent { - public MiddlewareStreamingAgent(T innerAgent, string? name = null) - : base(innerAgent, name) + public MiddlewareStreamingAgent(T innerAgent, string? name = null, IEnumerable? streamingMiddlewares = null) + : base(innerAgent, name, streamingMiddlewares) { - TAgent = innerAgent; + TStreamingAgent = innerAgent; } public MiddlewareStreamingAgent(MiddlewareStreamingAgent other) : base(other) { - TAgent = other.TAgent; + TStreamingAgent = other.TStreamingAgent; } /// /// Get the inner agent. /// - public T TAgent { get; } + public T TStreamingAgent { get; } } diff --git a/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs b/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs index c522c78f5..5beed7fd8 100644 --- a/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs +++ b/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs @@ -20,6 +20,7 @@ public static class MiddlewareExtension /// /// /// throw when agent name is null. + [Obsolete("Use RegisterMiddleware instead.")] public static MiddlewareAgent RegisterReply( this TAgent agent, Func, CancellationToken, Task> replyFunc) @@ -45,6 +46,7 @@ public static class MiddlewareExtension /// One example is , which print the formatted message to console before the agent return the reply. /// /// throw when agent name is null. + [Obsolete("Use RegisterMiddleware instead.")] public static MiddlewareAgent RegisterPostProcess( this TAgent agent, Func, IMessage, CancellationToken, Task> postprocessFunc) @@ -62,6 +64,7 @@ public static class MiddlewareExtension /// Register a pre process hook to an agent. The hook will be called before the agent generate the reply. This is useful when you want to modify the conversation history before the agent generate the reply. /// /// throw when agent name is null. + [Obsolete("Use RegisterMiddleware instead.")] public static MiddlewareAgent RegisterPreProcess( this TAgent agent, Func, CancellationToken, Task>> preprocessFunc) @@ -77,6 +80,7 @@ public static class MiddlewareExtension /// /// Register a middleware to an existing agent and return a new agent with the middleware. + /// To register a streaming middleware, use . /// public static MiddlewareAgent RegisterMiddleware( this TAgent agent, @@ -94,6 +98,7 @@ public static class MiddlewareExtension /// /// Register a middleware to an existing agent and return a new agent with the middleware. + /// To register a streaming middleware, use . /// public static MiddlewareAgent RegisterMiddleware( this TAgent agent, @@ -107,6 +112,7 @@ public static class MiddlewareExtension /// /// Register a middleware to an existing agent and return a new agent with the middleware. + /// To register a streaming middleware, use . /// public static MiddlewareAgent RegisterMiddleware( this MiddlewareAgent agent, @@ -124,6 +130,7 @@ public static class MiddlewareExtension /// /// Register a middleware to an existing agent and return a new agent with the middleware. + /// To register a streaming middleware, use . /// public static MiddlewareAgent RegisterMiddleware( this MiddlewareAgent agent, diff --git a/dotnet/src/AutoGen.Core/Extension/PrintMessageMiddlewareExtension.cs b/dotnet/src/AutoGen.Core/Extension/PrintMessageMiddlewareExtension.cs index deb196ca3..262b50d12 100644 --- a/dotnet/src/AutoGen.Core/Extension/PrintMessageMiddlewareExtension.cs +++ b/dotnet/src/AutoGen.Core/Extension/PrintMessageMiddlewareExtension.cs @@ -62,7 +62,7 @@ public static class PrintMessageMiddlewareExtension { var middleware = new PrintMessageMiddleware(); var middlewareAgent = new MiddlewareStreamingAgent(agent); - middlewareAgent.Use(middleware); + middlewareAgent.UseStreaming(middleware); return middlewareAgent; } diff --git a/dotnet/src/AutoGen.Core/Extension/StreamingMiddlewareExtension.cs b/dotnet/src/AutoGen.Core/Extension/StreamingMiddlewareExtension.cs index 901d7f249..2ec7b3f9f 100644 --- a/dotnet/src/AutoGen.Core/Extension/StreamingMiddlewareExtension.cs +++ b/dotnet/src/AutoGen.Core/Extension/StreamingMiddlewareExtension.cs @@ -1,17 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // StreamingMiddlewareExtension.cs -using System; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; - namespace AutoGen.Core; public static class StreamingMiddlewareExtension { /// - /// Register a middleware to an existing agent and return a new agent with the middleware. + /// Register an to an existing and return a new agent with the registered middleware. + /// For registering an , please refer to /// public static MiddlewareStreamingAgent RegisterStreamingMiddleware( this TStreamingAgent agent, @@ -21,16 +17,12 @@ public static class StreamingMiddlewareExtension var middlewareAgent = new MiddlewareStreamingAgent(agent); middlewareAgent.UseStreaming(middleware); - if (middleware is IMiddleware middlewareBase) - { - middlewareAgent.Use(middlewareBase); - } - return middlewareAgent; } /// - /// Register a middleware to an existing agent and return a new agent with the middleware. + /// Register an to an existing and return a new agent with the registered middleware. + /// For registering an , please refer to /// public static MiddlewareStreamingAgent RegisterStreamingMiddleware( this MiddlewareStreamingAgent agent, @@ -40,75 +32,6 @@ public static class StreamingMiddlewareExtension var copyAgent = new MiddlewareStreamingAgent(agent); copyAgent.UseStreaming(middleware); - if (middleware is IMiddleware middlewareBase) - { - copyAgent.Use(middlewareBase); - } - - return copyAgent; - } - - - /// - /// Register a middleware to an existing agent and return a new agent with the middleware. - /// - public static MiddlewareStreamingAgent RegisterStreamingMiddleware( - this TAgent agent, - Func>> func, - string? middlewareName = null) - where TAgent : IStreamingAgent - { - var middleware = new DelegateStreamingMiddleware(middlewareName, new DelegateStreamingMiddleware.MiddlewareDelegate(func)); - - return agent.RegisterStreamingMiddleware(middleware); - } - - /// - /// Register a streaming middleware to an existing agent and return a new agent with the middleware. - /// - public static MiddlewareStreamingAgent RegisterStreamingMiddleware( - this MiddlewareStreamingAgent agent, - Func>> func, - string? middlewareName = null) - where TAgent : IStreamingAgent - { - var middleware = new DelegateStreamingMiddleware(middlewareName, new DelegateStreamingMiddleware.MiddlewareDelegate(func)); - - return agent.RegisterStreamingMiddleware(middleware); - } - - /// - /// Register a middleware to an existing streaming agent and return a new agent with the middleware. - /// - public static MiddlewareStreamingAgent RegisterMiddleware( - this MiddlewareStreamingAgent streamingAgent, - Func, GenerateReplyOptions?, IAgent, CancellationToken, Task> func, - string? middlewareName = null) - where TStreamingAgent : IStreamingAgent - { - var middleware = new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) => - { - return await func(context.Messages, context.Options, agent, cancellationToken); - }); - - return streamingAgent.RegisterMiddleware(middleware); - } - - /// - /// Register a middleware to an existing streaming agent and return a new agent with the middleware. - /// - public static MiddlewareStreamingAgent RegisterMiddleware( - this MiddlewareStreamingAgent streamingAgent, - IMiddleware middleware) - where TStreamingAgent : IStreamingAgent - { - var copyAgent = new MiddlewareStreamingAgent(streamingAgent); - copyAgent.Use(middleware); - if (middleware is IStreamingMiddleware streamingMiddleware) - { - copyAgent.UseStreaming(streamingMiddleware); - } - return copyAgent; } } diff --git a/dotnet/src/AutoGen.Core/Middleware/DelegateStreamingMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/DelegateStreamingMiddleware.cs deleted file mode 100644 index 5499abccf..000000000 --- a/dotnet/src/AutoGen.Core/Middleware/DelegateStreamingMiddleware.cs +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// DelegateStreamingMiddleware.cs - -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; - -namespace AutoGen.Core; - -internal class DelegateStreamingMiddleware : IStreamingMiddleware -{ - public delegate Task> MiddlewareDelegate( - MiddlewareContext context, - IStreamingAgent agent, - CancellationToken cancellationToken); - - private readonly MiddlewareDelegate middlewareDelegate; - - public DelegateStreamingMiddleware(string? name, MiddlewareDelegate middlewareDelegate) - { - this.Name = name; - this.middlewareDelegate = middlewareDelegate; - } - - public string? Name { get; } - - public Task> InvokeAsync( - MiddlewareContext context, - IStreamingAgent agent, - CancellationToken cancellationToken = default) - { - var messages = context.Messages; - var options = context.Options; - - return this.middlewareDelegate(context, agent, cancellationToken); - } -} - diff --git a/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs index d00151b32..2bc028055 100644 --- a/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs +++ b/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs @@ -29,7 +29,7 @@ namespace AutoGen.Core; /// If the streaming reply from the inner agent is other types of message, the most recent message will be used to invoke the function. /// /// -public class FunctionCallMiddleware : IMiddleware, IStreamingMiddleware +public class FunctionCallMiddleware : IStreamingMiddleware { private readonly IEnumerable? functions; private readonly IDictionary>>? functionMap; @@ -71,15 +71,10 @@ public class FunctionCallMiddleware : IMiddleware, IStreamingMiddleware return reply; } - public Task> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default) - { - return Task.FromResult(this.StreamingInvokeAsync(context, agent, cancellationToken)); - } - - private async IAsyncEnumerable StreamingInvokeAsync( + public async IAsyncEnumerable InvokeAsync( MiddlewareContext context, IStreamingAgent agent, - [EnumeratorCancellation] CancellationToken cancellationToken) + [EnumeratorCancellation] CancellationToken cancellationToken = default) { var lastMessage = context.Messages.Last(); if (lastMessage is ToolCallMessage toolCallMessage) @@ -93,7 +88,7 @@ public class FunctionCallMiddleware : IMiddleware, IStreamingMiddleware options.Functions = combinedFunctions?.ToArray(); IStreamingMessage? initMessage = default; - await foreach (var message in await agent.GenerateStreamingReplyAsync(context.Messages, options, cancellationToken)) + await foreach (var message in agent.GenerateStreamingReplyAsync(context.Messages, options, cancellationToken)) { if (message is ToolCallMessageUpdate toolCallMessageUpdate && this.functionMap != null) { diff --git a/dotnet/src/AutoGen.Core/Middleware/IMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/IMiddleware.cs index 2813ee9cd..00ec5a97f 100644 --- a/dotnet/src/AutoGen.Core/Middleware/IMiddleware.cs +++ b/dotnet/src/AutoGen.Core/Middleware/IMiddleware.cs @@ -7,7 +7,7 @@ using System.Threading.Tasks; namespace AutoGen.Core; /// -/// The middleware interface +/// The middleware interface. For streaming-version middleware, check . /// public interface IMiddleware { diff --git a/dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs index b8965dcc4..bc7aec57f 100644 --- a/dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs +++ b/dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs @@ -3,18 +3,18 @@ using System.Collections.Generic; using System.Threading; -using System.Threading.Tasks; namespace AutoGen.Core; /// -/// The streaming middleware interface +/// The streaming middleware interface. For non-streaming version middleware, check . /// -public interface IStreamingMiddleware +public interface IStreamingMiddleware : IMiddleware { - public string? Name { get; } - - public Task> InvokeAsync( + /// + /// The streaming version of . + /// + public IAsyncEnumerable InvokeAsync( MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default); diff --git a/dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs index 9461b6973..099f78e5f 100644 --- a/dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs +++ b/dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs @@ -2,6 +2,8 @@ // PrintMessageMiddleware.cs using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -10,7 +12,7 @@ namespace AutoGen.Core; /// /// The middleware that prints the reply from agent to the console. /// -public class PrintMessageMiddleware : IMiddleware +public class PrintMessageMiddleware : IStreamingMiddleware { public string? Name => nameof(PrintMessageMiddleware); @@ -19,51 +21,12 @@ public class PrintMessageMiddleware : IMiddleware if (agent is IStreamingAgent streamingAgent) { IMessage? recentUpdate = null; - await foreach (var message in await streamingAgent.GenerateStreamingReplyAsync(context.Messages, context.Options, cancellationToken)) + await foreach (var message in this.InvokeAsync(context, streamingAgent, cancellationToken)) { - if (message is TextMessageUpdate textMessageUpdate) - { - if (recentUpdate is null) - { - // Print from: xxx - Console.WriteLine($"from: {textMessageUpdate.From}"); - recentUpdate = new TextMessage(textMessageUpdate); - Console.Write(textMessageUpdate.Content); - } - else if (recentUpdate is TextMessage recentTextMessage) - { - // Print the content of the message - Console.Write(textMessageUpdate.Content); - recentTextMessage.Update(textMessageUpdate); - } - else - { - throw new InvalidOperationException("The recent update is not a TextMessage"); - } - } - else if (message is ToolCallMessageUpdate toolCallUpdate) - { - if (recentUpdate is null) - { - recentUpdate = new ToolCallMessage(toolCallUpdate); - } - else if (recentUpdate is ToolCallMessage recentToolCallMessage) - { - recentToolCallMessage.Update(toolCallUpdate); - } - else - { - throw new InvalidOperationException("The recent update is not a ToolCallMessage"); - } - } - else if (message is IMessage imessage) + if (message is IMessage imessage) { recentUpdate = imessage; } - else - { - throw new InvalidOperationException("The message is not a valid message"); - } } Console.WriteLine(); if (recentUpdate is not null && recentUpdate is not TextMessage) @@ -84,4 +47,72 @@ public class PrintMessageMiddleware : IMiddleware return reply; } } + + public async IAsyncEnumerable InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + IMessage? recentUpdate = null; + await foreach (var message in agent.GenerateStreamingReplyAsync(context.Messages, context.Options, cancellationToken)) + { + if (message is TextMessageUpdate textMessageUpdate) + { + if (recentUpdate is null) + { + // Print from: xxx + Console.WriteLine($"from: {textMessageUpdate.From}"); + recentUpdate = new TextMessage(textMessageUpdate); + Console.Write(textMessageUpdate.Content); + + yield return message; + } + else if (recentUpdate is TextMessage recentTextMessage) + { + // Print the content of the message + Console.Write(textMessageUpdate.Content); + recentTextMessage.Update(textMessageUpdate); + + yield return recentTextMessage; + } + else + { + throw new InvalidOperationException("The recent update is not a TextMessage"); + } + } + else if (message is ToolCallMessageUpdate toolCallUpdate) + { + if (recentUpdate is null) + { + recentUpdate = new ToolCallMessage(toolCallUpdate); + + yield return message; + } + else if (recentUpdate is ToolCallMessage recentToolCallMessage) + { + recentToolCallMessage.Update(toolCallUpdate); + + yield return message; + } + else + { + throw new InvalidOperationException("The recent update is not a ToolCallMessage"); + } + } + else if (message is IMessage imessage) + { + recentUpdate = imessage; + + yield return imessage; + } + else + { + throw new InvalidOperationException("The message is not a valid message"); + } + } + Console.WriteLine(); + if (recentUpdate is not null && recentUpdate is not TextMessage) + { + Console.WriteLine(recentUpdate.FormatMessage()); + } + + yield return recentUpdate ?? throw new InvalidOperationException("The message is not a valid message"); + } } diff --git a/dotnet/src/AutoGen.DotnetInteractive/Extension/AgentExtension.cs b/dotnet/src/AutoGen.DotnetInteractive/Extension/AgentExtension.cs index 034ca170e..83955c53f 100644 --- a/dotnet/src/AutoGen.DotnetInteractive/Extension/AgentExtension.cs +++ b/dotnet/src/AutoGen.DotnetInteractive/Extension/AgentExtension.cs @@ -28,19 +28,19 @@ public static class AgentExtension string codeBlockSuffix = "```", int maximumOutputToKeep = 500) { - return agent.RegisterReply(async (msgs, ct) => + return agent.RegisterMiddleware(async (msgs, option, innerAgent, ct) => { var lastMessage = msgs.LastOrDefault(); if (lastMessage == null || lastMessage.GetContent() is null) { - return null; + return await innerAgent.GenerateReplyAsync(msgs, option, ct); } // retrieve all code blocks from last message var codeBlocks = lastMessage.GetContent()!.Split(new[] { codeBlockPrefix }, StringSplitOptions.RemoveEmptyEntries); if (codeBlocks.Length <= 0) { - return null; + return await innerAgent.GenerateReplyAsync(msgs, option, ct); } // run code blocks diff --git a/dotnet/src/AutoGen.Mistral/Agent/MistralClientAgent.cs b/dotnet/src/AutoGen.Mistral/Agent/MistralClientAgent.cs index 2ba28bbb7..cc2c74145 100644 --- a/dotnet/src/AutoGen.Mistral/Agent/MistralClientAgent.cs +++ b/dotnet/src/AutoGen.Mistral/Agent/MistralClientAgent.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using AutoGen.Core; @@ -77,19 +78,14 @@ public class MistralClientAgent : IStreamingAgent return new MessageEnvelope(response, from: this.Name); } - public async Task> GenerateStreamingReplyAsync( + public async IAsyncEnumerable GenerateStreamingReplyAsync( IEnumerable messages, GenerateReplyOptions? options = null, - CancellationToken cancellationToken = default) + [EnumeratorCancellation] CancellationToken cancellationToken = default) { var request = BuildChatRequest(messages, options); var response = _client.StreamingChatCompletionsAsync(request); - return ProcessMessage(response); - } - - private async IAsyncEnumerable ProcessMessage(IAsyncEnumerable response) - { await foreach (var content in response) { yield return new MessageEnvelope(content, from: this.Name); diff --git a/dotnet/src/AutoGen.Mistral/Extension/MistralAgentExtension.cs b/dotnet/src/AutoGen.Mistral/Extension/MistralAgentExtension.cs index 5b3c998b6..787393d06 100644 --- a/dotnet/src/AutoGen.Mistral/Extension/MistralAgentExtension.cs +++ b/dotnet/src/AutoGen.Mistral/Extension/MistralAgentExtension.cs @@ -18,9 +18,7 @@ public static class MistralAgentExtension connector = new MistralChatMessageConnector(); } - return agent.RegisterStreamingMiddleware(connector) - .RegisterMiddleware(connector); - + return agent.RegisterStreamingMiddleware(connector); } /// @@ -34,7 +32,6 @@ public static class MistralAgentExtension connector = new MistralChatMessageConnector(); } - return agent.RegisterStreamingMiddleware(connector) - .RegisterMiddleware(connector); + return agent.RegisterStreamingMiddleware(connector); } } diff --git a/dotnet/src/AutoGen.Mistral/Middleware/MistralChatMessageConnector.cs b/dotnet/src/AutoGen.Mistral/Middleware/MistralChatMessageConnector.cs index 44f34401e..3ba910aa7 100644 --- a/dotnet/src/AutoGen.Mistral/Middleware/MistralChatMessageConnector.cs +++ b/dotnet/src/AutoGen.Mistral/Middleware/MistralChatMessageConnector.cs @@ -15,17 +15,12 @@ public class MistralChatMessageConnector : IStreamingMiddleware, IMiddleware { public string? Name => nameof(MistralChatMessageConnector); - public Task> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default) - { - return Task.FromResult(StreamingInvoke(context, agent, cancellationToken)); - } - - private async IAsyncEnumerable StreamingInvoke(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default) { var messages = context.Messages; var chatMessages = ProcessMessage(messages, agent); var chunks = new List(); - await foreach (var reply in await agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken)) + await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken)) { if (reply is IStreamingMessage chatMessage) { diff --git a/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs b/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs index cb5a97c13..52070788e 100644 --- a/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs +++ b/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs @@ -90,30 +90,28 @@ public class GPTAgent : IStreamingAgent GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { - var agent = this._innerAgent - .RegisterMessageConnector(); - if (this.functionMap is not null) - { - var functionMapMiddleware = new FunctionCallMiddleware(functionMap: this.functionMap); - agent = agent.RegisterMiddleware(functionMapMiddleware); - } - - return await agent.GenerateReplyAsync(messages, options, cancellationToken); - } - - public async Task> GenerateStreamingReplyAsync( - IEnumerable messages, - GenerateReplyOptions? options = null, - CancellationToken cancellationToken = default) - { - var agent = this._innerAgent - .RegisterMessageConnector(); + var agent = this._innerAgent.RegisterMessageConnector(); if (this.functionMap is not null) { var functionMapMiddleware = new FunctionCallMiddleware(functionMap: this.functionMap); agent = agent.RegisterStreamingMiddleware(functionMapMiddleware); } - return await agent.GenerateStreamingReplyAsync(messages, options, cancellationToken); + return await agent.GenerateReplyAsync(messages, options, cancellationToken); + } + + public IAsyncEnumerable GenerateStreamingReplyAsync( + IEnumerable messages, + GenerateReplyOptions? options = null, + CancellationToken cancellationToken = default) + { + var agent = this._innerAgent.RegisterMessageConnector(); + if (this.functionMap is not null) + { + var functionMapMiddleware = new FunctionCallMiddleware(functionMap: this.functionMap); + agent = agent.RegisterStreamingMiddleware(functionMapMiddleware); + } + + return agent.GenerateStreamingReplyAsync(messages, options, cancellationToken); } } diff --git a/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs b/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs index 487a361d7..37a4882f6 100644 --- a/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs +++ b/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs @@ -87,15 +87,7 @@ public class OpenAIChatAgent : IStreamingAgent return new MessageEnvelope(reply, from: this.Name); } - public Task> GenerateStreamingReplyAsync( - IEnumerable messages, - GenerateReplyOptions? options = null, - CancellationToken cancellationToken = default) - { - return Task.FromResult(this.StreamingReplyAsync(messages, options, cancellationToken)); - } - - private async IAsyncEnumerable StreamingReplyAsync( + public async IAsyncEnumerable GenerateStreamingReplyAsync( IEnumerable messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) diff --git a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs index 118d99703..2bd9470ff 100644 --- a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs +++ b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs @@ -44,22 +44,14 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa return PostProcessMessage(reply); } - public async Task> InvokeAsync( + public async IAsyncEnumerable InvokeAsync( MiddlewareContext context, IStreamingAgent agent, - CancellationToken cancellationToken = default) - { - return InvokeStreamingAsync(context, agent, cancellationToken); - } - - private async IAsyncEnumerable InvokeStreamingAsync( - MiddlewareContext context, - IStreamingAgent agent, - [EnumeratorCancellation] CancellationToken cancellationToken) + [EnumeratorCancellation] CancellationToken cancellationToken = default) { var chatMessages = ProcessIncomingMessages(agent, context.Messages) .Select(m => new MessageEnvelope(m)); - var streamingReply = await agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken); + var streamingReply = agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken); string? currentToolName = null; await foreach (var reply in streamingReply) { @@ -135,6 +127,12 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa private IMessage PostProcessMessage(IMessage message) { + // throw exception if prompt filter results is not null + if (message.Content.Choices[0].FinishReason == CompletionsFinishReason.ContentFiltered) + { + throw new InvalidOperationException("The content is filtered because its potential risk. Please try another input."); + } + return PostProcessMessage(message.Content.Choices[0].Message, message.From); } diff --git a/dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs b/dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs index 557683c96..6a8395ef2 100644 --- a/dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs +++ b/dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs @@ -47,20 +47,12 @@ public class SemanticKernelChatMessageContentConnector : IMiddleware, IStreaming return PostProcessMessage(reply); } - public Task> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default) - { - return Task.FromResult(InvokeStreamingAsync(context, agent, cancellationToken)); - } - - private async IAsyncEnumerable InvokeStreamingAsync( - MiddlewareContext context, - IStreamingAgent agent, - [EnumeratorCancellation] CancellationToken cancellationToken) + public async IAsyncEnumerable InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default) { var chatMessageContents = ProcessMessage(context.Messages, agent) .Select(m => new MessageEnvelope(m)); - await foreach (var reply in await agent.GenerateStreamingReplyAsync(chatMessageContents, context.Options, cancellationToken)) + await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessageContents, context.Options, cancellationToken)) { yield return PostProcessStreamingMessage(reply); } diff --git a/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs b/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs index b887a6ef5..21f652f56 100644 --- a/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs +++ b/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Microsoft.SemanticKernel; @@ -64,17 +65,25 @@ public class SemanticKernelAgent : IStreamingAgent return new MessageEnvelope(reply.First(), from: this.Name); } - public async Task> GenerateStreamingReplyAsync( + public async IAsyncEnumerable GenerateStreamingReplyAsync( IEnumerable messages, GenerateReplyOptions? options = null, - CancellationToken cancellationToken = default) + [EnumeratorCancellation] CancellationToken cancellationToken = default) { var chatHistory = BuildChatHistory(messages); var option = BuildOption(options); var chatService = _kernel.GetRequiredService(); var response = chatService.GetStreamingChatMessageContentsAsync(chatHistory, option, _kernel, cancellationToken); - return ProcessMessage(response); + await foreach (var content in response) + { + if (content.ChoiceIndex > 0) + { + throw new InvalidOperationException("Only one choice is supported in streaming response"); + } + + yield return new MessageEnvelope(content, from: this.Name); + } } private ChatHistory BuildChatHistory(IEnumerable messages) @@ -101,19 +110,6 @@ public class SemanticKernelAgent : IStreamingAgent }; } - private async IAsyncEnumerable ProcessMessage(IAsyncEnumerable response) - { - await foreach (var content in response) - { - if (content.ChoiceIndex > 0) - { - throw new InvalidOperationException("Only one choice is supported in streaming response"); - } - - yield return new MessageEnvelope(content, from: this.Name); - } - } - private IEnumerable ProcessMessage(IEnumerable messages) { return messages.Select(m => m switch diff --git a/dotnet/src/AutoGen/Agent/ConversableAgent.cs b/dotnet/src/AutoGen/Agent/ConversableAgent.cs index e70a74a80..d79d25192 100644 --- a/dotnet/src/AutoGen/Agent/ConversableAgent.cs +++ b/dotnet/src/AutoGen/Agent/ConversableAgent.cs @@ -79,19 +79,33 @@ public class ConversableAgent : IAgent IAgent? agent = null; foreach (var llmConfig in config.ConfigList ?? Enumerable.Empty()) { - agent = agent switch + var nextAgent = llmConfig switch { - null => llmConfig switch - { - AzureOpenAIConfig azureConfig => new GPTAgent(this.Name!, this.systemMessage, azureConfig, temperature: config.Temperature ?? 0), - OpenAIConfig openAIConfig => new GPTAgent(this.Name!, this.systemMessage, openAIConfig, temperature: config.Temperature ?? 0), - _ => throw new ArgumentException($"Unsupported config type {llmConfig.GetType()}"), - }, - IAgent innerAgent => innerAgent.RegisterReply(async (messages, cancellationToken) => - { - return await innerAgent.GenerateReplyAsync(messages, cancellationToken: cancellationToken); - }), + AzureOpenAIConfig azureConfig => new GPTAgent(this.Name!, this.systemMessage, azureConfig, temperature: config.Temperature ?? 0), + OpenAIConfig openAIConfig => new GPTAgent(this.Name!, this.systemMessage, openAIConfig, temperature: config.Temperature ?? 0), + _ => throw new ArgumentException($"Unsupported config type {llmConfig.GetType()}"), }; + + if (agent == null) + { + agent = nextAgent; + } + else + { + agent = agent.RegisterMiddleware(async (messages, option, agent, cancellationToken) => + { + var agentResponse = await nextAgent.GenerateReplyAsync(messages, option, cancellationToken: cancellationToken); + + if (agentResponse is null) + { + return await agent.GenerateReplyAsync(messages, option, cancellationToken); + } + else + { + return agentResponse; + } + }); + } } return agent; diff --git a/dotnet/src/AutoGen/AutoGen.csproj b/dotnet/src/AutoGen/AutoGen.csproj index 2b9aaed6d..8f4bbccb5 100644 --- a/dotnet/src/AutoGen/AutoGen.csproj +++ b/dotnet/src/AutoGen/AutoGen.csproj @@ -20,7 +20,6 @@ - diff --git a/dotnet/test/AutoGen.Mistral.Tests/MistralClientAgentTests.cs b/dotnet/test/AutoGen.Mistral.Tests/MistralClientAgentTests.cs index 110e81fdb..5a9d1f95c 100644 --- a/dotnet/test/AutoGen.Mistral.Tests/MistralClientAgentTests.cs +++ b/dotnet/test/AutoGen.Mistral.Tests/MistralClientAgentTests.cs @@ -114,7 +114,7 @@ public partial class MistralClientAgentTests model: "mistral-small-latest", randomSeed: 0) .RegisterMessageConnector() - .RegisterMiddleware(functionCallMiddleware); + .RegisterStreamingMiddleware(functionCallMiddleware); var functionCallMiddlewareExecutorMiddleware = new FunctionCallMiddleware( functionMap: new Dictionary>> @@ -127,7 +127,7 @@ public partial class MistralClientAgentTests model: "mistral-small-latest", randomSeed: 0) .RegisterMessageConnector() - .RegisterMiddleware(functionCallMiddlewareExecutorMiddleware); + .RegisterStreamingMiddleware(functionCallMiddlewareExecutorMiddleware); await twoAgentTest.TwoAgentGetWeatherFunctionCallTestAsync(executorAgent, functionCallAgent); } @@ -148,7 +148,7 @@ public partial class MistralClientAgentTests model: "mistral-small-latest", randomSeed: 0) .RegisterMessageConnector() - .RegisterMiddleware(functionCallMiddleware); + .RegisterStreamingMiddleware(functionCallMiddleware); var question = new TextMessage(Role.User, "what's the weather in Seattle?"); var reply = await functionCallAgent.SendAsync(question); @@ -193,7 +193,7 @@ public partial class MistralClientAgentTests toolChoice: ToolChoiceEnum.Any, randomSeed: 0) .RegisterMessageConnector() - .RegisterMiddleware(functionCallMiddleware); + .RegisterStreamingMiddleware(functionCallMiddleware); await singleAgentTest.EchoFunctionCallExecutionTestAsync(agent); await singleAgentTest.EchoFunctionCallExecutionStreamingTestAsync(agent); } @@ -214,7 +214,7 @@ public partial class MistralClientAgentTests systemMessage: "You are a helpful assistant that can call functions", randomSeed: 0) .RegisterMessageConnector() - .RegisterMiddleware(functionCallMiddleware); + .RegisterStreamingMiddleware(functionCallMiddleware); await singleAgentTest.EchoFunctionCallTestAsync(agent); @@ -222,7 +222,7 @@ public partial class MistralClientAgentTests var question = new TextMessage(Role.User, "what's the weather in Seattle?"); IMessage? finalReply = null; - await foreach (var reply in await agent.GenerateStreamingReplyAsync([question])) + await foreach (var reply in agent.GenerateStreamingReplyAsync([question])) { reply.From.Should().Be(agent.Name); if (reply is IMessage message) diff --git a/dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs b/dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs index a4753b668..c504eb06a 100644 --- a/dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs @@ -47,7 +47,7 @@ public partial class OpenAIChatAgentTest reply.As>().Content.Usage.TotalTokens.Should().BeGreaterThan(0); // test streaming - var streamingReply = await openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent }); + var streamingReply = openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent }); await foreach (var streamingMessage in streamingReply) { @@ -93,7 +93,7 @@ public partial class OpenAIChatAgentTest // test streaming foreach (var message in messages) { - var reply = await assistant.GenerateStreamingReplyAsync([message]); + var reply = assistant.GenerateStreamingReplyAsync([message]); await foreach (var streamingMessage in reply) { @@ -119,10 +119,9 @@ public partial class OpenAIChatAgentTest MiddlewareStreamingAgent assistant = openAIChatAgent .RegisterMessageConnector(); - assistant.Middlewares.Count().Should().Be(1); assistant.StreamingMiddlewares.Count().Should().Be(1); var functionCallAgent = assistant - .RegisterMiddleware(functionCallMiddleware); + .RegisterStreamingMiddleware(functionCallMiddleware); var question = "What's the weather in Seattle"; var messages = new IMessage[] @@ -150,7 +149,7 @@ public partial class OpenAIChatAgentTest // test streaming foreach (var message in messages) { - var reply = await functionCallAgent.GenerateStreamingReplyAsync([message]); + var reply = functionCallAgent.GenerateStreamingReplyAsync([message]); ToolCallMessage? toolCallMessage = null; await foreach (var streamingMessage in reply) { @@ -191,7 +190,7 @@ public partial class OpenAIChatAgentTest .RegisterMessageConnector(); var functionCallAgent = assistant - .RegisterMiddleware(functionCallMiddleware); + .RegisterStreamingMiddleware(functionCallMiddleware); var question = "What's the weather in Seattle"; var messages = new IMessage[] @@ -220,7 +219,7 @@ public partial class OpenAIChatAgentTest // test streaming foreach (var message in messages) { - var reply = await functionCallAgent.GenerateStreamingReplyAsync([message]); + var reply = functionCallAgent.GenerateStreamingReplyAsync([message]); await foreach (var streamingMessage in reply) { if (streamingMessage is not IMessage) diff --git a/dotnet/test/AutoGen.Tests/RegisterReplyAgentTest.cs b/dotnet/test/AutoGen.Tests/RegisterReplyAgentTest.cs deleted file mode 100644 index d4866ad87..000000000 --- a/dotnet/test/AutoGen.Tests/RegisterReplyAgentTest.cs +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// RegisterReplyAgentTest.cs - -using System.Threading.Tasks; -using FluentAssertions; -using Xunit; - -namespace AutoGen.Tests -{ - public class RegisterReplyAgentTest - { - [Fact] - public async Task RegisterReplyTestAsync() - { - IAgent echoAgent = new EchoAgent("echo"); - echoAgent = echoAgent - .RegisterReply(async (conversations, ct) => new TextMessage(Role.Assistant, "I'm your father", from: echoAgent.Name)); - - var msg = new Message(Role.User, "hey"); - var reply = await echoAgent.SendAsync(msg); - reply.Should().BeOfType(); - reply.GetContent().Should().Be("I'm your father"); - reply.GetRole().Should().Be(Role.Assistant); - reply.From.Should().Be("echo"); - } - } -} diff --git a/dotnet/test/AutoGen.Tests/SemanticKernelAgentTest.cs b/dotnet/test/AutoGen.Tests/SemanticKernelAgentTest.cs index 2e5b56f80..dcb5cd47b 100644 --- a/dotnet/test/AutoGen.Tests/SemanticKernelAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/SemanticKernelAgentTest.cs @@ -44,7 +44,7 @@ public partial class SemanticKernelAgentTest reply.As>().From.Should().Be("assistant"); // test streaming - var streamingReply = await skAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent }); + var streamingReply = skAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent }); await foreach (var streamingMessage in streamingReply) { @@ -63,10 +63,8 @@ public partial class SemanticKernelAgentTest var kernel = builder.Build(); - var connector = new SemanticKernelChatMessageContentConnector(); var skAgent = new SemanticKernelAgent(kernel, "assistant") - .RegisterStreamingMiddleware(connector) - .RegisterMiddleware(connector); + .RegisterMessageConnector(); var messages = new IMessage[] { @@ -90,7 +88,7 @@ public partial class SemanticKernelAgentTest // test streaming foreach (var message in messages) { - var reply = await skAgent.GenerateStreamingReplyAsync([message]); + var reply = skAgent.GenerateStreamingReplyAsync([message]); await foreach (var streamingMessage in reply) { @@ -122,7 +120,6 @@ public partial class SemanticKernelAgentTest var skAgent = new SemanticKernelAgent(kernel, "assistant") .RegisterMessageConnector(); - skAgent.Middlewares.Count().Should().Be(1); skAgent.StreamingMiddlewares.Count().Should().Be(1); var question = "What is the weather in Seattle?"; diff --git a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs index 6dfb61761..ae566889b 100644 --- a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs @@ -261,7 +261,7 @@ namespace AutoGen.Tests { Temperature = 0, }; - var replyStream = await agent.GenerateStreamingReplyAsync(messages: new[] { message, helloWorld }, option); + var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message, helloWorld }, option); var answer = "[ECHO] Hello world"; IStreamingMessage? finalReply = default; await foreach (var reply in replyStream) @@ -302,7 +302,7 @@ namespace AutoGen.Tests { Temperature = 0, }; - var replyStream = await agent.GenerateStreamingReplyAsync(messages: new[] { message, helloWorld }, option); + var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message, helloWorld }, option); var answer = "A B C D E F G H I J K L M N"; TextMessage? finalReply = default; await foreach (var reply in replyStream) diff --git a/dotnet/website/update.md b/dotnet/website/update.md index a97b94805..b65ab128e 100644 --- a/dotnet/website/update.md +++ b/dotnet/website/update.md @@ -1,6 +1,9 @@ +##### Update +- [API Breaking Change] Update the return type of `IStreamingAgent.GenerateStreamingReplyAsync` from `Task>` to `IAsyncEnumerable` +- [API Breaking Change] Update the return type of `IStreamingMiddleware.InvokeAsync` from `Task>` to `IAsyncEnumerable` +- [API Breaking Change] Mark `RegisterReply`, `RegisterPreProcess` and `RegisterPostProcess` as obsolete. You can replace them with `RegisterMiddleware` ##### Update on 0.0.12 (2024-04-22) - Add AutoGen.Mistral package to support Mistral.AI models - ##### Update on 0.0.11 (2024-04-10) - Add link to Discord channel in nuget's readme.md - Document improvements