mirror of
https://github.com/microsoft/autogen.git
synced 2026-05-13 03:00:55 -04:00
[.Net] refactor over streaming version api (#2461)
* update * update * fix comment
This commit is contained in:
@@ -19,7 +19,7 @@ internal class AgentCodeSnippet
|
||||
|
||||
#region ChatWithAnAgent_GenerateStreamingReplyAsync
|
||||
var textMessage = new TextMessage(Role.User, "Hello");
|
||||
await foreach (var streamingReply in await agent.GenerateStreamingReplyAsync([message]))
|
||||
await foreach (var streamingReply in agent.GenerateStreamingReplyAsync([message]))
|
||||
{
|
||||
if (streamingReply is TextMessageUpdate update)
|
||||
{
|
||||
|
||||
@@ -11,7 +11,7 @@ internal class BuildInMessageCodeSnippet
|
||||
IStreamingAgent agent = default;
|
||||
#region StreamingCallCodeSnippet
|
||||
var helloTextMessage = new TextMessage(Role.User, "Hello");
|
||||
var reply = await agent.GenerateStreamingReplyAsync([helloTextMessage]);
|
||||
var reply = agent.GenerateStreamingReplyAsync([helloTextMessage]);
|
||||
var finalTextMessage = new TextMessage(Role.Assistant, string.Empty, from: agent.Name);
|
||||
await foreach (var message in reply)
|
||||
{
|
||||
@@ -24,7 +24,7 @@ internal class BuildInMessageCodeSnippet
|
||||
#endregion StreamingCallCodeSnippet
|
||||
|
||||
#region StreamingCallWithFinalMessage
|
||||
reply = await agent.GenerateStreamingReplyAsync([helloTextMessage]);
|
||||
reply = agent.GenerateStreamingReplyAsync([helloTextMessage]);
|
||||
TextMessage finalMessage = null;
|
||||
await foreach (var message in reply)
|
||||
{
|
||||
|
||||
@@ -38,7 +38,7 @@ internal class MistralAICodeSnippet
|
||||
#endregion create_mistral_agent
|
||||
|
||||
#region streaming_chat
|
||||
var reply = await agent.GenerateStreamingReplyAsync(
|
||||
var reply = agent.GenerateStreamingReplyAsync(
|
||||
messages: [new TextMessage(Role.User, "Hello, how are you?")]
|
||||
);
|
||||
|
||||
@@ -75,7 +75,7 @@ internal class MistralAICodeSnippet
|
||||
#endregion create_get_weather_function_call_middleware
|
||||
|
||||
#region register_function_call_middleware
|
||||
agent = agent.RegisterMiddleware(functionCallMiddleware);
|
||||
agent = agent.RegisterStreamingMiddleware(functionCallMiddleware);
|
||||
#endregion register_function_call_middleware
|
||||
|
||||
#region send_message_with_function_call
|
||||
|
||||
@@ -60,7 +60,7 @@ public partial class OpenAICodeSnippet
|
||||
#endregion create_openai_chat_agent
|
||||
|
||||
#region create_openai_chat_agent_streaming
|
||||
var streamingReply = await openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent });
|
||||
var streamingReply = openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent });
|
||||
|
||||
await foreach (var streamingMessage in streamingReply)
|
||||
{
|
||||
@@ -123,7 +123,7 @@ public partial class OpenAICodeSnippet
|
||||
{ functions.GetWeatherFunctionContract.Name, functions.GetWeatherWrapper } // GetWeatherWrapper is a wrapper function for GetWeather, which is also auto-generated
|
||||
});
|
||||
|
||||
openAIChatAgent = openAIChatAgent.RegisterMiddleware(functionCallMiddleware);
|
||||
openAIChatAgent = openAIChatAgent.RegisterStreamingMiddleware(functionCallMiddleware);
|
||||
#endregion create_function_call_middleware
|
||||
|
||||
#region chat_agent_send_function_call
|
||||
|
||||
@@ -49,7 +49,7 @@ public class SemanticKernelCodeSnippet
|
||||
#endregion create_semantic_kernel_agent
|
||||
|
||||
#region create_semantic_kernel_agent_streaming
|
||||
var streamingReply = await semanticKernelAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent });
|
||||
var streamingReply = semanticKernelAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent });
|
||||
|
||||
await foreach (var streamingMessage in streamingReply)
|
||||
{
|
||||
|
||||
@@ -18,16 +18,17 @@ public static class Example02_TwoAgent_MathChat
|
||||
var teacher = new AssistantAgent(
|
||||
name: "teacher",
|
||||
systemMessage: @"You are a teacher that create pre-school math question for student and check answer.
|
||||
If the answer is correct, you terminate conversation by saying [TERMINATE].
|
||||
If the answer is correct, you stop the conversation by saying [COMPLETE].
|
||||
If the answer is wrong, you ask student to fix it.",
|
||||
llmConfig: new ConversableAgentConfig
|
||||
{
|
||||
Temperature = 0,
|
||||
ConfigList = [gpt35],
|
||||
})
|
||||
.RegisterPostProcess(async (_, reply, _) =>
|
||||
.RegisterMiddleware(async (msgs, option, agent, _) =>
|
||||
{
|
||||
if (reply.GetContent()?.ToLower().Contains("terminate") is true)
|
||||
var reply = await agent.GenerateReplyAsync(msgs, option);
|
||||
if (reply.GetContent()?.ToLower().Contains("complete") is true)
|
||||
{
|
||||
return new TextMessage(Role.Assistant, GroupChatExtension.TERMINATE, from: reply.From);
|
||||
}
|
||||
|
||||
@@ -85,26 +85,16 @@ public partial class Example07_Dynamic_GroupChat_Calculate_Fibonacci
|
||||
systemMessage: "You run dotnet code",
|
||||
defaultReply: "No code available.")
|
||||
.RegisterDotnetCodeBlockExectionHook(interactiveService: service)
|
||||
.RegisterReply(async (msgs, _) =>
|
||||
.RegisterMiddleware(async (msgs, option, agent, _) =>
|
||||
{
|
||||
if (msgs.Count() == 0)
|
||||
if (msgs.Count() == 0 || msgs.All(msg => msg.From != "coder"))
|
||||
{
|
||||
return new TextMessage(Role.Assistant, "No code available. Coder please write code");
|
||||
}
|
||||
|
||||
return null;
|
||||
})
|
||||
.RegisterPreProcess(async (msgs, _) =>
|
||||
{
|
||||
// retrieve the most recent message from coder
|
||||
var coderMsg = msgs.LastOrDefault(msg => msg.From == "coder");
|
||||
if (coderMsg is null)
|
||||
{
|
||||
return Enumerable.Empty<IMessage>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return new[] { coderMsg };
|
||||
var coderMsg = msgs.Last(msg => msg.From == "coder");
|
||||
return await agent.GenerateReplyAsync([coderMsg], option);
|
||||
}
|
||||
})
|
||||
.RegisterPrintMessage();
|
||||
@@ -122,8 +112,9 @@ public partial class Example07_Dynamic_GroupChat_Calculate_Fibonacci
|
||||
systemMessage: "You are group admin, terminate the group chat once task is completed by saying [TERMINATE] plus the final answer",
|
||||
temperature: 0,
|
||||
config: gpt3Config)
|
||||
.RegisterPostProcess(async (_, reply, _) =>
|
||||
.RegisterMiddleware(async (msgs, option, agent, _) =>
|
||||
{
|
||||
var reply = await agent.GenerateReplyAsync(msgs, option);
|
||||
if (reply is TextMessage textMessage && textMessage.Content.Contains("TERMINATE") is true)
|
||||
{
|
||||
var content = $"{textMessage.Content}\n\n {GroupChatExtension.TERMINATE}";
|
||||
|
||||
@@ -62,7 +62,7 @@ public class Example10_SemanticKernel
|
||||
Console.WriteLine((reply as IMessage<ChatMessageContent>).Content.Items[0].As<TextContent>().Text);
|
||||
|
||||
var skAgentWithMiddleware = skAgent
|
||||
.RegisterMessageConnector()
|
||||
.RegisterMessageConnector() // Register the message connector to support more AutoGen built-in message types
|
||||
.RegisterPrintMessage();
|
||||
|
||||
// Now the skAgentWithMiddleware supports more IMessage types like TextMessage, ImageMessage or MultiModalMessage
|
||||
|
||||
@@ -28,7 +28,8 @@ public class Example13_OpenAIAgent_JsonMode
|
||||
systemMessage: "You are a helpful assistant designed to output JSON.",
|
||||
seed: 0, // explicitly set a seed to enable deterministic output
|
||||
responseFormat: ChatCompletionsResponseFormat.JsonObject) // set response format to JSON object to enable JSON mode
|
||||
.RegisterMessageConnector();
|
||||
.RegisterMessageConnector()
|
||||
.RegisterPrintMessage();
|
||||
#endregion create_agent
|
||||
|
||||
#region chat_with_agent
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Program.cs
|
||||
|
||||
using AutoGen.BasicSample;
|
||||
await Example14_MistralClientAgent_TokenCount.RunAsync();
|
||||
await Example02_TwoAgent_MathChat.RunAsync();
|
||||
|
||||
@@ -23,7 +23,7 @@ public interface IMiddlewareAgent : IAgent
|
||||
void Use(IMiddleware middleware);
|
||||
}
|
||||
|
||||
public interface IMiddlewareStreamAgent : IMiddlewareAgent, IStreamingAgent
|
||||
public interface IMiddlewareStreamAgent : IStreamingAgent
|
||||
{
|
||||
/// <summary>
|
||||
/// Get the inner agent.
|
||||
@@ -44,7 +44,11 @@ public interface IMiddlewareAgent<out T> : IMiddlewareAgent
|
||||
T TAgent { get; }
|
||||
}
|
||||
|
||||
public interface IMiddlewareStreamAgent<out T> : IMiddlewareStreamAgent, IMiddlewareAgent<T>
|
||||
public interface IMiddlewareStreamAgent<out T> : IMiddlewareStreamAgent
|
||||
where T : IStreamingAgent
|
||||
{
|
||||
/// <summary>
|
||||
/// Get the typed inner agent.
|
||||
/// </summary>
|
||||
T TStreamingAgent { get; }
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
using System.Collections.Generic;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace AutoGen.Core;
|
||||
|
||||
@@ -12,7 +11,7 @@ namespace AutoGen.Core;
|
||||
/// </summary>
|
||||
public interface IStreamingAgent : IAgent
|
||||
{
|
||||
public Task<IAsyncEnumerable<IStreamingMessage>> GenerateStreamingReplyAsync(
|
||||
public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
|
||||
IEnumerable<IMessage> messages,
|
||||
GenerateReplyOptions? options = null,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
@@ -14,7 +14,7 @@ namespace AutoGen.Core;
|
||||
/// </summary>
|
||||
public class MiddlewareAgent : IMiddlewareAgent
|
||||
{
|
||||
private readonly IAgent _agent;
|
||||
private IAgent _agent;
|
||||
private readonly List<IMiddleware> middlewares = new();
|
||||
|
||||
/// <summary>
|
||||
@@ -22,10 +22,17 @@ public class MiddlewareAgent : IMiddlewareAgent
|
||||
/// </summary>
|
||||
/// <param name="innerAgent">the inner agent where middleware will be added.</param>
|
||||
/// <param name="name">the name of the agent if provided. Otherwise, the name of <paramref name="innerAgent"/> will be used.</param>
|
||||
public MiddlewareAgent(IAgent innerAgent, string? name = null)
|
||||
public MiddlewareAgent(IAgent innerAgent, string? name = null, IEnumerable<IMiddleware>? middlewares = null)
|
||||
{
|
||||
this.Name = name ?? innerAgent.Name;
|
||||
this._agent = innerAgent;
|
||||
if (middlewares != null && middlewares.Any())
|
||||
{
|
||||
foreach (var middleware in middlewares)
|
||||
{
|
||||
this.Use(middleware);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
@@ -55,13 +62,7 @@ public class MiddlewareAgent : IMiddlewareAgent
|
||||
GenerateReplyOptions? options = null,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
IAgent agent = this._agent;
|
||||
foreach (var middleware in this.middlewares)
|
||||
{
|
||||
agent = new DelegateAgent(middleware, agent);
|
||||
}
|
||||
|
||||
return agent.GenerateReplyAsync(messages, options, cancellationToken);
|
||||
return _agent.GenerateReplyAsync(messages, options, cancellationToken);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
@@ -71,15 +72,18 @@ public class MiddlewareAgent : IMiddlewareAgent
|
||||
/// </summary>
|
||||
public void Use(Func<IEnumerable<IMessage>, GenerateReplyOptions?, IAgent, CancellationToken, Task<IMessage>> func, string? middlewareName = null)
|
||||
{
|
||||
this.middlewares.Add(new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) =>
|
||||
var middleware = new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) =>
|
||||
{
|
||||
return await func(context.Messages, context.Options, agent, cancellationToken);
|
||||
}));
|
||||
});
|
||||
|
||||
this.Use(middleware);
|
||||
}
|
||||
|
||||
public void Use(IMiddleware middleware)
|
||||
{
|
||||
this.middlewares.Add(middleware);
|
||||
_agent = new DelegateAgent(middleware, _agent);
|
||||
}
|
||||
|
||||
public override string ToString()
|
||||
|
||||
@@ -2,33 +2,31 @@
|
||||
// MiddlewareStreamingAgent.cs
|
||||
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace AutoGen.Core;
|
||||
|
||||
public class MiddlewareStreamingAgent : MiddlewareAgent, IMiddlewareStreamAgent
|
||||
public class MiddlewareStreamingAgent : IMiddlewareStreamAgent
|
||||
{
|
||||
private readonly IStreamingAgent _agent;
|
||||
private IStreamingAgent _agent;
|
||||
private readonly List<IStreamingMiddleware> _streamingMiddlewares = new();
|
||||
private readonly List<IMiddleware> _middlewares = new();
|
||||
|
||||
public MiddlewareStreamingAgent(
|
||||
IStreamingAgent agent,
|
||||
string? name = null,
|
||||
IEnumerable<IStreamingMiddleware>? streamingMiddlewares = null,
|
||||
IEnumerable<IMiddleware>? middlewares = null)
|
||||
: base(agent, name)
|
||||
IEnumerable<IStreamingMiddleware>? streamingMiddlewares = null)
|
||||
{
|
||||
this.Name = name ?? agent.Name;
|
||||
_agent = agent;
|
||||
if (streamingMiddlewares != null)
|
||||
{
|
||||
_streamingMiddlewares.AddRange(streamingMiddlewares);
|
||||
}
|
||||
|
||||
if (middlewares != null)
|
||||
if (streamingMiddlewares != null && streamingMiddlewares.Any())
|
||||
{
|
||||
_middlewares.AddRange(middlewares);
|
||||
foreach (var middleware in streamingMiddlewares)
|
||||
{
|
||||
this.UseStreaming(middleware);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,26 +40,28 @@ public class MiddlewareStreamingAgent : MiddlewareAgent, IMiddlewareStreamAgent
|
||||
/// </summary>
|
||||
public IEnumerable<IStreamingMiddleware> StreamingMiddlewares => _streamingMiddlewares;
|
||||
|
||||
public Task<IAsyncEnumerable<IStreamingMessage>> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
var agent = _agent;
|
||||
foreach (var middleware in _streamingMiddlewares)
|
||||
{
|
||||
agent = new DelegateStreamingAgent(middleware, agent);
|
||||
}
|
||||
public string Name { get; }
|
||||
|
||||
return agent.GenerateStreamingReplyAsync(messages, options, cancellationToken);
|
||||
public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
return _agent.GenerateReplyAsync(messages, options, cancellationToken);
|
||||
}
|
||||
|
||||
public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
|
||||
return _agent.GenerateStreamingReplyAsync(messages, options, cancellationToken);
|
||||
}
|
||||
|
||||
public void UseStreaming(IStreamingMiddleware middleware)
|
||||
{
|
||||
_streamingMiddlewares.Add(middleware);
|
||||
_agent = new DelegateStreamingAgent(middleware, _agent);
|
||||
}
|
||||
|
||||
private class DelegateStreamingAgent : IStreamingAgent
|
||||
{
|
||||
private IStreamingMiddleware? streamingMiddleware;
|
||||
private IMiddleware? middleware;
|
||||
private IStreamingAgent innerAgent;
|
||||
|
||||
public string Name => innerAgent.Name;
|
||||
@@ -72,24 +72,19 @@ public class MiddlewareStreamingAgent : MiddlewareAgent, IMiddlewareStreamAgent
|
||||
this.innerAgent = next;
|
||||
}
|
||||
|
||||
public DelegateStreamingAgent(IMiddleware middleware, IStreamingAgent next)
|
||||
{
|
||||
this.middleware = middleware;
|
||||
this.innerAgent = next;
|
||||
}
|
||||
|
||||
public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
|
||||
public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (middleware is null)
|
||||
if (this.streamingMiddleware is null)
|
||||
{
|
||||
return await innerAgent.GenerateReplyAsync(messages, options, cancellationToken);
|
||||
return innerAgent.GenerateReplyAsync(messages, options, cancellationToken);
|
||||
}
|
||||
|
||||
var context = new MiddlewareContext(messages, options);
|
||||
return await middleware.InvokeAsync(context, innerAgent, cancellationToken);
|
||||
return this.streamingMiddleware.InvokeAsync(context, (IAgent)innerAgent, cancellationToken);
|
||||
}
|
||||
|
||||
public Task<IAsyncEnumerable<IStreamingMessage>> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
|
||||
public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (streamingMiddleware is null)
|
||||
{
|
||||
@@ -105,20 +100,20 @@ public class MiddlewareStreamingAgent : MiddlewareAgent, IMiddlewareStreamAgent
|
||||
public sealed class MiddlewareStreamingAgent<T> : MiddlewareStreamingAgent, IMiddlewareStreamAgent<T>
|
||||
where T : IStreamingAgent
|
||||
{
|
||||
public MiddlewareStreamingAgent(T innerAgent, string? name = null)
|
||||
: base(innerAgent, name)
|
||||
public MiddlewareStreamingAgent(T innerAgent, string? name = null, IEnumerable<IStreamingMiddleware>? streamingMiddlewares = null)
|
||||
: base(innerAgent, name, streamingMiddlewares)
|
||||
{
|
||||
TAgent = innerAgent;
|
||||
TStreamingAgent = innerAgent;
|
||||
}
|
||||
|
||||
public MiddlewareStreamingAgent(MiddlewareStreamingAgent<T> other)
|
||||
: base(other)
|
||||
{
|
||||
TAgent = other.TAgent;
|
||||
TStreamingAgent = other.TStreamingAgent;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Get the inner agent.
|
||||
/// </summary>
|
||||
public T TAgent { get; }
|
||||
public T TStreamingAgent { get; }
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ public static class MiddlewareExtension
|
||||
/// <param name="replyFunc"></param>
|
||||
/// <returns></returns>
|
||||
/// <exception cref="Exception">throw when agent name is null.</exception>
|
||||
[Obsolete("Use RegisterMiddleware instead.")]
|
||||
public static MiddlewareAgent<TAgent> RegisterReply<TAgent>(
|
||||
this TAgent agent,
|
||||
Func<IEnumerable<IMessage>, CancellationToken, Task<IMessage?>> replyFunc)
|
||||
@@ -45,6 +46,7 @@ public static class MiddlewareExtension
|
||||
/// One example is <see cref="PrintMessageMiddlewareExtension.RegisterPrintMessage{TAgent}(TAgent)" />, which print the formatted message to console before the agent return the reply.
|
||||
/// </summary>
|
||||
/// <exception cref="Exception">throw when agent name is null.</exception>
|
||||
[Obsolete("Use RegisterMiddleware instead.")]
|
||||
public static MiddlewareAgent<TAgent> RegisterPostProcess<TAgent>(
|
||||
this TAgent agent,
|
||||
Func<IEnumerable<IMessage>, IMessage, CancellationToken, Task<IMessage>> postprocessFunc)
|
||||
@@ -62,6 +64,7 @@ public static class MiddlewareExtension
|
||||
/// Register a pre process hook to an agent. The hook will be called before the agent generate the reply. This is useful when you want to modify the conversation history before the agent generate the reply.
|
||||
/// </summary>
|
||||
/// <exception cref="Exception">throw when agent name is null.</exception>
|
||||
[Obsolete("Use RegisterMiddleware instead.")]
|
||||
public static MiddlewareAgent<TAgent> RegisterPreProcess<TAgent>(
|
||||
this TAgent agent,
|
||||
Func<IEnumerable<IMessage>, CancellationToken, Task<IEnumerable<IMessage>>> preprocessFunc)
|
||||
@@ -77,6 +80,7 @@ public static class MiddlewareExtension
|
||||
|
||||
/// <summary>
|
||||
/// Register a middleware to an existing agent and return a new agent with the middleware.
|
||||
/// To register a streaming middleware, use <see cref="StreamingMiddlewareExtension.RegisterStreamingMiddleware{TAgent}(MiddlewareStreamingAgent{TAgent}, IStreamingMiddleware)"/>.
|
||||
/// </summary>
|
||||
public static MiddlewareAgent<TAgent> RegisterMiddleware<TAgent>(
|
||||
this TAgent agent,
|
||||
@@ -94,6 +98,7 @@ public static class MiddlewareExtension
|
||||
|
||||
/// <summary>
|
||||
/// Register a middleware to an existing agent and return a new agent with the middleware.
|
||||
/// To register a streaming middleware, use <see cref="StreamingMiddlewareExtension.RegisterStreamingMiddleware{TAgent}(MiddlewareStreamingAgent{TAgent}, IStreamingMiddleware)"/>.
|
||||
/// </summary>
|
||||
public static MiddlewareAgent<TAgent> RegisterMiddleware<TAgent>(
|
||||
this TAgent agent,
|
||||
@@ -107,6 +112,7 @@ public static class MiddlewareExtension
|
||||
|
||||
/// <summary>
|
||||
/// Register a middleware to an existing agent and return a new agent with the middleware.
|
||||
/// To register a streaming middleware, use <see cref="StreamingMiddlewareExtension.RegisterStreamingMiddleware{TAgent}(MiddlewareStreamingAgent{TAgent}, IStreamingMiddleware)"/>.
|
||||
/// </summary>
|
||||
public static MiddlewareAgent<TAgent> RegisterMiddleware<TAgent>(
|
||||
this MiddlewareAgent<TAgent> agent,
|
||||
@@ -124,6 +130,7 @@ public static class MiddlewareExtension
|
||||
|
||||
/// <summary>
|
||||
/// Register a middleware to an existing agent and return a new agent with the middleware.
|
||||
/// To register a streaming middleware, use <see cref="StreamingMiddlewareExtension.RegisterStreamingMiddleware{TAgent}(MiddlewareStreamingAgent{TAgent}, IStreamingMiddleware)"/>.
|
||||
/// </summary>
|
||||
public static MiddlewareAgent<TAgent> RegisterMiddleware<TAgent>(
|
||||
this MiddlewareAgent<TAgent> agent,
|
||||
|
||||
@@ -62,7 +62,7 @@ public static class PrintMessageMiddlewareExtension
|
||||
{
|
||||
var middleware = new PrintMessageMiddleware();
|
||||
var middlewareAgent = new MiddlewareStreamingAgent<TAgent>(agent);
|
||||
middlewareAgent.Use(middleware);
|
||||
middlewareAgent.UseStreaming(middleware);
|
||||
|
||||
return middlewareAgent;
|
||||
}
|
||||
|
||||
@@ -1,17 +1,13 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// StreamingMiddlewareExtension.cs
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace AutoGen.Core;
|
||||
|
||||
public static class StreamingMiddlewareExtension
|
||||
{
|
||||
/// <summary>
|
||||
/// Register a middleware to an existing agent and return a new agent with the middleware.
|
||||
/// Register an <see cref="IStreamingMiddleware"/> to an existing <see cref="IStreamingAgent"/> and return a new agent with the registered middleware.
|
||||
/// For registering an <see cref="IMiddleware"/>, please refer to <see cref="MiddlewareExtension.RegisterMiddleware{TAgent}(MiddlewareAgent{TAgent}, IMiddleware)"/>
|
||||
/// </summary>
|
||||
public static MiddlewareStreamingAgent<TStreamingAgent> RegisterStreamingMiddleware<TStreamingAgent>(
|
||||
this TStreamingAgent agent,
|
||||
@@ -21,16 +17,12 @@ public static class StreamingMiddlewareExtension
|
||||
var middlewareAgent = new MiddlewareStreamingAgent<TStreamingAgent>(agent);
|
||||
middlewareAgent.UseStreaming(middleware);
|
||||
|
||||
if (middleware is IMiddleware middlewareBase)
|
||||
{
|
||||
middlewareAgent.Use(middlewareBase);
|
||||
}
|
||||
|
||||
return middlewareAgent;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Register a middleware to an existing agent and return a new agent with the middleware.
|
||||
/// Register an <see cref="IStreamingMiddleware"/> to an existing <see cref="IStreamingAgent"/> and return a new agent with the registered middleware.
|
||||
/// For registering an <see cref="IMiddleware"/>, please refer to <see cref="MiddlewareExtension.RegisterMiddleware{TAgent}(MiddlewareAgent{TAgent}, IMiddleware)"/>
|
||||
/// </summary>
|
||||
public static MiddlewareStreamingAgent<TAgent> RegisterStreamingMiddleware<TAgent>(
|
||||
this MiddlewareStreamingAgent<TAgent> agent,
|
||||
@@ -40,75 +32,6 @@ public static class StreamingMiddlewareExtension
|
||||
var copyAgent = new MiddlewareStreamingAgent<TAgent>(agent);
|
||||
copyAgent.UseStreaming(middleware);
|
||||
|
||||
if (middleware is IMiddleware middlewareBase)
|
||||
{
|
||||
copyAgent.Use(middlewareBase);
|
||||
}
|
||||
|
||||
return copyAgent;
|
||||
}
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// Register a middleware to an existing agent and return a new agent with the middleware.
|
||||
/// </summary>
|
||||
public static MiddlewareStreamingAgent<TAgent> RegisterStreamingMiddleware<TAgent>(
|
||||
this TAgent agent,
|
||||
Func<MiddlewareContext, IStreamingAgent, CancellationToken, Task<IAsyncEnumerable<IStreamingMessage>>> func,
|
||||
string? middlewareName = null)
|
||||
where TAgent : IStreamingAgent
|
||||
{
|
||||
var middleware = new DelegateStreamingMiddleware(middlewareName, new DelegateStreamingMiddleware.MiddlewareDelegate(func));
|
||||
|
||||
return agent.RegisterStreamingMiddleware(middleware);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Register a streaming middleware to an existing agent and return a new agent with the middleware.
|
||||
/// </summary>
|
||||
public static MiddlewareStreamingAgent<TAgent> RegisterStreamingMiddleware<TAgent>(
|
||||
this MiddlewareStreamingAgent<TAgent> agent,
|
||||
Func<MiddlewareContext, IStreamingAgent, CancellationToken, Task<IAsyncEnumerable<IStreamingMessage>>> func,
|
||||
string? middlewareName = null)
|
||||
where TAgent : IStreamingAgent
|
||||
{
|
||||
var middleware = new DelegateStreamingMiddleware(middlewareName, new DelegateStreamingMiddleware.MiddlewareDelegate(func));
|
||||
|
||||
return agent.RegisterStreamingMiddleware(middleware);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Register a middleware to an existing streaming agent and return a new agent with the middleware.
|
||||
/// </summary>
|
||||
public static MiddlewareStreamingAgent<TStreamingAgent> RegisterMiddleware<TStreamingAgent>(
|
||||
this MiddlewareStreamingAgent<TStreamingAgent> streamingAgent,
|
||||
Func<IEnumerable<IMessage>, GenerateReplyOptions?, IAgent, CancellationToken, Task<IMessage>> func,
|
||||
string? middlewareName = null)
|
||||
where TStreamingAgent : IStreamingAgent
|
||||
{
|
||||
var middleware = new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) =>
|
||||
{
|
||||
return await func(context.Messages, context.Options, agent, cancellationToken);
|
||||
});
|
||||
|
||||
return streamingAgent.RegisterMiddleware(middleware);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Register a middleware to an existing streaming agent and return a new agent with the middleware.
|
||||
/// </summary>
|
||||
public static MiddlewareStreamingAgent<TStreamingAgent> RegisterMiddleware<TStreamingAgent>(
|
||||
this MiddlewareStreamingAgent<TStreamingAgent> streamingAgent,
|
||||
IMiddleware middleware)
|
||||
where TStreamingAgent : IStreamingAgent
|
||||
{
|
||||
var copyAgent = new MiddlewareStreamingAgent<TStreamingAgent>(streamingAgent);
|
||||
copyAgent.Use(middleware);
|
||||
if (middleware is IStreamingMiddleware streamingMiddleware)
|
||||
{
|
||||
copyAgent.UseStreaming(streamingMiddleware);
|
||||
}
|
||||
|
||||
return copyAgent;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// DelegateStreamingMiddleware.cs
|
||||
|
||||
using System.Collections.Generic;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace AutoGen.Core;
|
||||
|
||||
internal class DelegateStreamingMiddleware : IStreamingMiddleware
|
||||
{
|
||||
public delegate Task<IAsyncEnumerable<IStreamingMessage>> MiddlewareDelegate(
|
||||
MiddlewareContext context,
|
||||
IStreamingAgent agent,
|
||||
CancellationToken cancellationToken);
|
||||
|
||||
private readonly MiddlewareDelegate middlewareDelegate;
|
||||
|
||||
public DelegateStreamingMiddleware(string? name, MiddlewareDelegate middlewareDelegate)
|
||||
{
|
||||
this.Name = name;
|
||||
this.middlewareDelegate = middlewareDelegate;
|
||||
}
|
||||
|
||||
public string? Name { get; }
|
||||
|
||||
public Task<IAsyncEnumerable<IStreamingMessage>> InvokeAsync(
|
||||
MiddlewareContext context,
|
||||
IStreamingAgent agent,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var messages = context.Messages;
|
||||
var options = context.Options;
|
||||
|
||||
return this.middlewareDelegate(context, agent, cancellationToken);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ namespace AutoGen.Core;
|
||||
/// If the streaming reply from the inner agent is other types of message, the most recent message will be used to invoke the function.
|
||||
/// </para>
|
||||
/// </summary>
|
||||
public class FunctionCallMiddleware : IMiddleware, IStreamingMiddleware
|
||||
public class FunctionCallMiddleware : IStreamingMiddleware
|
||||
{
|
||||
private readonly IEnumerable<FunctionContract>? functions;
|
||||
private readonly IDictionary<string, Func<string, Task<string>>>? functionMap;
|
||||
@@ -71,15 +71,10 @@ public class FunctionCallMiddleware : IMiddleware, IStreamingMiddleware
|
||||
return reply;
|
||||
}
|
||||
|
||||
public Task<IAsyncEnumerable<IStreamingMessage>> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default)
|
||||
{
|
||||
return Task.FromResult(this.StreamingInvokeAsync(context, agent, cancellationToken));
|
||||
}
|
||||
|
||||
private async IAsyncEnumerable<IStreamingMessage> StreamingInvokeAsync(
|
||||
public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(
|
||||
MiddlewareContext context,
|
||||
IStreamingAgent agent,
|
||||
[EnumeratorCancellation] CancellationToken cancellationToken)
|
||||
[EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
var lastMessage = context.Messages.Last();
|
||||
if (lastMessage is ToolCallMessage toolCallMessage)
|
||||
@@ -93,7 +88,7 @@ public class FunctionCallMiddleware : IMiddleware, IStreamingMiddleware
|
||||
options.Functions = combinedFunctions?.ToArray();
|
||||
|
||||
IStreamingMessage? initMessage = default;
|
||||
await foreach (var message in await agent.GenerateStreamingReplyAsync(context.Messages, options, cancellationToken))
|
||||
await foreach (var message in agent.GenerateStreamingReplyAsync(context.Messages, options, cancellationToken))
|
||||
{
|
||||
if (message is ToolCallMessageUpdate toolCallMessageUpdate && this.functionMap != null)
|
||||
{
|
||||
|
||||
@@ -7,7 +7,7 @@ using System.Threading.Tasks;
|
||||
namespace AutoGen.Core;
|
||||
|
||||
/// <summary>
|
||||
/// The middleware interface
|
||||
/// The middleware interface. For streaming-version middleware, check <see cref="IStreamingMiddleware"/>.
|
||||
/// </summary>
|
||||
public interface IMiddleware
|
||||
{
|
||||
|
||||
@@ -3,18 +3,18 @@
|
||||
|
||||
using System.Collections.Generic;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace AutoGen.Core;
|
||||
|
||||
/// <summary>
|
||||
/// The streaming middleware interface
|
||||
/// The streaming middleware interface. For non-streaming version middleware, check <see cref="IMiddleware"/>.
|
||||
/// </summary>
|
||||
public interface IStreamingMiddleware
|
||||
public interface IStreamingMiddleware : IMiddleware
|
||||
{
|
||||
public string? Name { get; }
|
||||
|
||||
public Task<IAsyncEnumerable<IStreamingMessage>> InvokeAsync(
|
||||
/// <summary>
|
||||
/// The streaming version of <see cref="IMiddleware.InvokeAsync(MiddlewareContext, IAgent, CancellationToken)"/>.
|
||||
/// </summary>
|
||||
public IAsyncEnumerable<IStreamingMessage> InvokeAsync(
|
||||
MiddlewareContext context,
|
||||
IStreamingAgent agent,
|
||||
CancellationToken cancellationToken = default);
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
// PrintMessageMiddleware.cs
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
@@ -10,7 +12,7 @@ namespace AutoGen.Core;
|
||||
/// <summary>
|
||||
/// The middleware that prints the reply from agent to the console.
|
||||
/// </summary>
|
||||
public class PrintMessageMiddleware : IMiddleware
|
||||
public class PrintMessageMiddleware : IStreamingMiddleware
|
||||
{
|
||||
public string? Name => nameof(PrintMessageMiddleware);
|
||||
|
||||
@@ -19,51 +21,12 @@ public class PrintMessageMiddleware : IMiddleware
|
||||
if (agent is IStreamingAgent streamingAgent)
|
||||
{
|
||||
IMessage? recentUpdate = null;
|
||||
await foreach (var message in await streamingAgent.GenerateStreamingReplyAsync(context.Messages, context.Options, cancellationToken))
|
||||
await foreach (var message in this.InvokeAsync(context, streamingAgent, cancellationToken))
|
||||
{
|
||||
if (message is TextMessageUpdate textMessageUpdate)
|
||||
{
|
||||
if (recentUpdate is null)
|
||||
{
|
||||
// Print from: xxx
|
||||
Console.WriteLine($"from: {textMessageUpdate.From}");
|
||||
recentUpdate = new TextMessage(textMessageUpdate);
|
||||
Console.Write(textMessageUpdate.Content);
|
||||
}
|
||||
else if (recentUpdate is TextMessage recentTextMessage)
|
||||
{
|
||||
// Print the content of the message
|
||||
Console.Write(textMessageUpdate.Content);
|
||||
recentTextMessage.Update(textMessageUpdate);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException("The recent update is not a TextMessage");
|
||||
}
|
||||
}
|
||||
else if (message is ToolCallMessageUpdate toolCallUpdate)
|
||||
{
|
||||
if (recentUpdate is null)
|
||||
{
|
||||
recentUpdate = new ToolCallMessage(toolCallUpdate);
|
||||
}
|
||||
else if (recentUpdate is ToolCallMessage recentToolCallMessage)
|
||||
{
|
||||
recentToolCallMessage.Update(toolCallUpdate);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException("The recent update is not a ToolCallMessage");
|
||||
}
|
||||
}
|
||||
else if (message is IMessage imessage)
|
||||
if (message is IMessage imessage)
|
||||
{
|
||||
recentUpdate = imessage;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException("The message is not a valid message");
|
||||
}
|
||||
}
|
||||
Console.WriteLine();
|
||||
if (recentUpdate is not null && recentUpdate is not TextMessage)
|
||||
@@ -84,4 +47,72 @@ public class PrintMessageMiddleware : IMiddleware
|
||||
return reply;
|
||||
}
|
||||
}
|
||||
|
||||
public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
IMessage? recentUpdate = null;
|
||||
await foreach (var message in agent.GenerateStreamingReplyAsync(context.Messages, context.Options, cancellationToken))
|
||||
{
|
||||
if (message is TextMessageUpdate textMessageUpdate)
|
||||
{
|
||||
if (recentUpdate is null)
|
||||
{
|
||||
// Print from: xxx
|
||||
Console.WriteLine($"from: {textMessageUpdate.From}");
|
||||
recentUpdate = new TextMessage(textMessageUpdate);
|
||||
Console.Write(textMessageUpdate.Content);
|
||||
|
||||
yield return message;
|
||||
}
|
||||
else if (recentUpdate is TextMessage recentTextMessage)
|
||||
{
|
||||
// Print the content of the message
|
||||
Console.Write(textMessageUpdate.Content);
|
||||
recentTextMessage.Update(textMessageUpdate);
|
||||
|
||||
yield return recentTextMessage;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException("The recent update is not a TextMessage");
|
||||
}
|
||||
}
|
||||
else if (message is ToolCallMessageUpdate toolCallUpdate)
|
||||
{
|
||||
if (recentUpdate is null)
|
||||
{
|
||||
recentUpdate = new ToolCallMessage(toolCallUpdate);
|
||||
|
||||
yield return message;
|
||||
}
|
||||
else if (recentUpdate is ToolCallMessage recentToolCallMessage)
|
||||
{
|
||||
recentToolCallMessage.Update(toolCallUpdate);
|
||||
|
||||
yield return message;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException("The recent update is not a ToolCallMessage");
|
||||
}
|
||||
}
|
||||
else if (message is IMessage imessage)
|
||||
{
|
||||
recentUpdate = imessage;
|
||||
|
||||
yield return imessage;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException("The message is not a valid message");
|
||||
}
|
||||
}
|
||||
Console.WriteLine();
|
||||
if (recentUpdate is not null && recentUpdate is not TextMessage)
|
||||
{
|
||||
Console.WriteLine(recentUpdate.FormatMessage());
|
||||
}
|
||||
|
||||
yield return recentUpdate ?? throw new InvalidOperationException("The message is not a valid message");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,19 +28,19 @@ public static class AgentExtension
|
||||
string codeBlockSuffix = "```",
|
||||
int maximumOutputToKeep = 500)
|
||||
{
|
||||
return agent.RegisterReply(async (msgs, ct) =>
|
||||
return agent.RegisterMiddleware(async (msgs, option, innerAgent, ct) =>
|
||||
{
|
||||
var lastMessage = msgs.LastOrDefault();
|
||||
if (lastMessage == null || lastMessage.GetContent() is null)
|
||||
{
|
||||
return null;
|
||||
return await innerAgent.GenerateReplyAsync(msgs, option, ct);
|
||||
}
|
||||
|
||||
// retrieve all code blocks from last message
|
||||
var codeBlocks = lastMessage.GetContent()!.Split(new[] { codeBlockPrefix }, StringSplitOptions.RemoveEmptyEntries);
|
||||
if (codeBlocks.Length <= 0)
|
||||
{
|
||||
return null;
|
||||
return await innerAgent.GenerateReplyAsync(msgs, option, ct);
|
||||
}
|
||||
|
||||
// run code blocks
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using AutoGen.Core;
|
||||
@@ -77,19 +78,14 @@ public class MistralClientAgent : IStreamingAgent
|
||||
return new MessageEnvelope<ChatCompletionResponse>(response, from: this.Name);
|
||||
}
|
||||
|
||||
public async Task<IAsyncEnumerable<IStreamingMessage>> GenerateStreamingReplyAsync(
|
||||
public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
|
||||
IEnumerable<IMessage> messages,
|
||||
GenerateReplyOptions? options = null,
|
||||
CancellationToken cancellationToken = default)
|
||||
[EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
var request = BuildChatRequest(messages, options);
|
||||
var response = _client.StreamingChatCompletionsAsync(request);
|
||||
|
||||
return ProcessMessage(response);
|
||||
}
|
||||
|
||||
private async IAsyncEnumerable<IMessage> ProcessMessage(IAsyncEnumerable<ChatCompletionResponse> response)
|
||||
{
|
||||
await foreach (var content in response)
|
||||
{
|
||||
yield return new MessageEnvelope<ChatCompletionResponse>(content, from: this.Name);
|
||||
|
||||
@@ -18,9 +18,7 @@ public static class MistralAgentExtension
|
||||
connector = new MistralChatMessageConnector();
|
||||
}
|
||||
|
||||
return agent.RegisterStreamingMiddleware(connector)
|
||||
.RegisterMiddleware(connector);
|
||||
|
||||
return agent.RegisterStreamingMiddleware(connector);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
@@ -34,7 +32,6 @@ public static class MistralAgentExtension
|
||||
connector = new MistralChatMessageConnector();
|
||||
}
|
||||
|
||||
return agent.RegisterStreamingMiddleware(connector)
|
||||
.RegisterMiddleware(connector);
|
||||
return agent.RegisterStreamingMiddleware(connector);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,17 +15,12 @@ public class MistralChatMessageConnector : IStreamingMiddleware, IMiddleware
|
||||
{
|
||||
public string? Name => nameof(MistralChatMessageConnector);
|
||||
|
||||
public Task<IAsyncEnumerable<IStreamingMessage>> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default)
|
||||
{
|
||||
return Task.FromResult(StreamingInvoke(context, agent, cancellationToken));
|
||||
}
|
||||
|
||||
private async IAsyncEnumerable<IStreamingMessage> StreamingInvoke(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
var messages = context.Messages;
|
||||
var chatMessages = ProcessMessage(messages, agent);
|
||||
var chunks = new List<ChatCompletionResponse>();
|
||||
await foreach (var reply in await agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken))
|
||||
await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken))
|
||||
{
|
||||
if (reply is IStreamingMessage<ChatCompletionResponse> chatMessage)
|
||||
{
|
||||
|
||||
@@ -90,30 +90,28 @@ public class GPTAgent : IStreamingAgent
|
||||
GenerateReplyOptions? options = null,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var agent = this._innerAgent
|
||||
.RegisterMessageConnector();
|
||||
if (this.functionMap is not null)
|
||||
{
|
||||
var functionMapMiddleware = new FunctionCallMiddleware(functionMap: this.functionMap);
|
||||
agent = agent.RegisterMiddleware(functionMapMiddleware);
|
||||
}
|
||||
|
||||
return await agent.GenerateReplyAsync(messages, options, cancellationToken);
|
||||
}
|
||||
|
||||
public async Task<IAsyncEnumerable<IStreamingMessage>> GenerateStreamingReplyAsync(
|
||||
IEnumerable<IMessage> messages,
|
||||
GenerateReplyOptions? options = null,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var agent = this._innerAgent
|
||||
.RegisterMessageConnector();
|
||||
var agent = this._innerAgent.RegisterMessageConnector();
|
||||
if (this.functionMap is not null)
|
||||
{
|
||||
var functionMapMiddleware = new FunctionCallMiddleware(functionMap: this.functionMap);
|
||||
agent = agent.RegisterStreamingMiddleware(functionMapMiddleware);
|
||||
}
|
||||
|
||||
return await agent.GenerateStreamingReplyAsync(messages, options, cancellationToken);
|
||||
return await agent.GenerateReplyAsync(messages, options, cancellationToken);
|
||||
}
|
||||
|
||||
public IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
|
||||
IEnumerable<IMessage> messages,
|
||||
GenerateReplyOptions? options = null,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
var agent = this._innerAgent.RegisterMessageConnector();
|
||||
if (this.functionMap is not null)
|
||||
{
|
||||
var functionMapMiddleware = new FunctionCallMiddleware(functionMap: this.functionMap);
|
||||
agent = agent.RegisterStreamingMiddleware(functionMapMiddleware);
|
||||
}
|
||||
|
||||
return agent.GenerateStreamingReplyAsync(messages, options, cancellationToken);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,15 +87,7 @@ public class OpenAIChatAgent : IStreamingAgent
|
||||
return new MessageEnvelope<ChatCompletions>(reply, from: this.Name);
|
||||
}
|
||||
|
||||
public Task<IAsyncEnumerable<IStreamingMessage>> GenerateStreamingReplyAsync(
|
||||
IEnumerable<IMessage> messages,
|
||||
GenerateReplyOptions? options = null,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
return Task.FromResult(this.StreamingReplyAsync(messages, options, cancellationToken));
|
||||
}
|
||||
|
||||
private async IAsyncEnumerable<IStreamingMessage> StreamingReplyAsync(
|
||||
public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
|
||||
IEnumerable<IMessage> messages,
|
||||
GenerateReplyOptions? options = null,
|
||||
[EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
|
||||
@@ -44,22 +44,14 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa
|
||||
return PostProcessMessage(reply);
|
||||
}
|
||||
|
||||
public async Task<IAsyncEnumerable<IStreamingMessage>> InvokeAsync(
|
||||
public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(
|
||||
MiddlewareContext context,
|
||||
IStreamingAgent agent,
|
||||
CancellationToken cancellationToken = default)
|
||||
{
|
||||
return InvokeStreamingAsync(context, agent, cancellationToken);
|
||||
}
|
||||
|
||||
private async IAsyncEnumerable<IStreamingMessage> InvokeStreamingAsync(
|
||||
MiddlewareContext context,
|
||||
IStreamingAgent agent,
|
||||
[EnumeratorCancellation] CancellationToken cancellationToken)
|
||||
[EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
var chatMessages = ProcessIncomingMessages(agent, context.Messages)
|
||||
.Select(m => new MessageEnvelope<ChatRequestMessage>(m));
|
||||
var streamingReply = await agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken);
|
||||
var streamingReply = agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken);
|
||||
string? currentToolName = null;
|
||||
await foreach (var reply in streamingReply)
|
||||
{
|
||||
@@ -135,6 +127,12 @@ public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddlewa
|
||||
|
||||
private IMessage PostProcessMessage(IMessage<ChatCompletions> message)
|
||||
{
|
||||
// throw exception if prompt filter results is not null
|
||||
if (message.Content.Choices[0].FinishReason == CompletionsFinishReason.ContentFiltered)
|
||||
{
|
||||
throw new InvalidOperationException("The content is filtered because its potential risk. Please try another input.");
|
||||
}
|
||||
|
||||
return PostProcessMessage(message.Content.Choices[0].Message, message.From);
|
||||
}
|
||||
|
||||
|
||||
@@ -47,20 +47,12 @@ public class SemanticKernelChatMessageContentConnector : IMiddleware, IStreaming
|
||||
return PostProcessMessage(reply);
|
||||
}
|
||||
|
||||
public Task<IAsyncEnumerable<IStreamingMessage>> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default)
|
||||
{
|
||||
return Task.FromResult(InvokeStreamingAsync(context, agent, cancellationToken));
|
||||
}
|
||||
|
||||
private async IAsyncEnumerable<IStreamingMessage> InvokeStreamingAsync(
|
||||
MiddlewareContext context,
|
||||
IStreamingAgent agent,
|
||||
[EnumeratorCancellation] CancellationToken cancellationToken)
|
||||
public async IAsyncEnumerable<IStreamingMessage> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
var chatMessageContents = ProcessMessage(context.Messages, agent)
|
||||
.Select(m => new MessageEnvelope<ChatMessageContent>(m));
|
||||
|
||||
await foreach (var reply in await agent.GenerateStreamingReplyAsync(chatMessageContents, context.Options, cancellationToken))
|
||||
await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessageContents, context.Options, cancellationToken))
|
||||
{
|
||||
yield return PostProcessStreamingMessage(reply);
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.SemanticKernel;
|
||||
@@ -64,17 +65,25 @@ public class SemanticKernelAgent : IStreamingAgent
|
||||
return new MessageEnvelope<ChatMessageContent>(reply.First(), from: this.Name);
|
||||
}
|
||||
|
||||
public async Task<IAsyncEnumerable<IStreamingMessage>> GenerateStreamingReplyAsync(
|
||||
public async IAsyncEnumerable<IStreamingMessage> GenerateStreamingReplyAsync(
|
||||
IEnumerable<IMessage> messages,
|
||||
GenerateReplyOptions? options = null,
|
||||
CancellationToken cancellationToken = default)
|
||||
[EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||
{
|
||||
var chatHistory = BuildChatHistory(messages);
|
||||
var option = BuildOption(options);
|
||||
var chatService = _kernel.GetRequiredService<IChatCompletionService>();
|
||||
var response = chatService.GetStreamingChatMessageContentsAsync(chatHistory, option, _kernel, cancellationToken);
|
||||
|
||||
return ProcessMessage(response);
|
||||
await foreach (var content in response)
|
||||
{
|
||||
if (content.ChoiceIndex > 0)
|
||||
{
|
||||
throw new InvalidOperationException("Only one choice is supported in streaming response");
|
||||
}
|
||||
|
||||
yield return new MessageEnvelope<StreamingChatMessageContent>(content, from: this.Name);
|
||||
}
|
||||
}
|
||||
|
||||
private ChatHistory BuildChatHistory(IEnumerable<IMessage> messages)
|
||||
@@ -101,19 +110,6 @@ public class SemanticKernelAgent : IStreamingAgent
|
||||
};
|
||||
}
|
||||
|
||||
private async IAsyncEnumerable<IMessage> ProcessMessage(IAsyncEnumerable<StreamingChatMessageContent> response)
|
||||
{
|
||||
await foreach (var content in response)
|
||||
{
|
||||
if (content.ChoiceIndex > 0)
|
||||
{
|
||||
throw new InvalidOperationException("Only one choice is supported in streaming response");
|
||||
}
|
||||
|
||||
yield return new MessageEnvelope<StreamingChatMessageContent>(content, from: this.Name);
|
||||
}
|
||||
}
|
||||
|
||||
private IEnumerable<ChatMessageContent> ProcessMessage(IEnumerable<IMessage> messages)
|
||||
{
|
||||
return messages.Select(m => m switch
|
||||
|
||||
@@ -79,19 +79,33 @@ public class ConversableAgent : IAgent
|
||||
IAgent? agent = null;
|
||||
foreach (var llmConfig in config.ConfigList ?? Enumerable.Empty<ILLMConfig>())
|
||||
{
|
||||
agent = agent switch
|
||||
var nextAgent = llmConfig switch
|
||||
{
|
||||
null => llmConfig switch
|
||||
{
|
||||
AzureOpenAIConfig azureConfig => new GPTAgent(this.Name!, this.systemMessage, azureConfig, temperature: config.Temperature ?? 0),
|
||||
OpenAIConfig openAIConfig => new GPTAgent(this.Name!, this.systemMessage, openAIConfig, temperature: config.Temperature ?? 0),
|
||||
_ => throw new ArgumentException($"Unsupported config type {llmConfig.GetType()}"),
|
||||
},
|
||||
IAgent innerAgent => innerAgent.RegisterReply(async (messages, cancellationToken) =>
|
||||
{
|
||||
return await innerAgent.GenerateReplyAsync(messages, cancellationToken: cancellationToken);
|
||||
}),
|
||||
AzureOpenAIConfig azureConfig => new GPTAgent(this.Name!, this.systemMessage, azureConfig, temperature: config.Temperature ?? 0),
|
||||
OpenAIConfig openAIConfig => new GPTAgent(this.Name!, this.systemMessage, openAIConfig, temperature: config.Temperature ?? 0),
|
||||
_ => throw new ArgumentException($"Unsupported config type {llmConfig.GetType()}"),
|
||||
};
|
||||
|
||||
if (agent == null)
|
||||
{
|
||||
agent = nextAgent;
|
||||
}
|
||||
else
|
||||
{
|
||||
agent = agent.RegisterMiddleware(async (messages, option, agent, cancellationToken) =>
|
||||
{
|
||||
var agentResponse = await nextAgent.GenerateReplyAsync(messages, option, cancellationToken: cancellationToken);
|
||||
|
||||
if (agentResponse is null)
|
||||
{
|
||||
return await agent.GenerateReplyAsync(messages, option, cancellationToken);
|
||||
}
|
||||
else
|
||||
{
|
||||
return agentResponse;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return agent;
|
||||
|
||||
@@ -20,7 +20,6 @@
|
||||
<ProjectReference Include="..\AutoGen.Mistral\AutoGen.Mistral.csproj" />
|
||||
<ProjectReference Include="..\AutoGen.SemanticKernel\AutoGen.SemanticKernel.csproj" />
|
||||
<ProjectReference Include="..\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
|
||||
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
|
||||
@@ -114,7 +114,7 @@ public partial class MistralClientAgentTests
|
||||
model: "mistral-small-latest",
|
||||
randomSeed: 0)
|
||||
.RegisterMessageConnector()
|
||||
.RegisterMiddleware(functionCallMiddleware);
|
||||
.RegisterStreamingMiddleware(functionCallMiddleware);
|
||||
|
||||
var functionCallMiddlewareExecutorMiddleware = new FunctionCallMiddleware(
|
||||
functionMap: new Dictionary<string, Func<string, Task<string>>>
|
||||
@@ -127,7 +127,7 @@ public partial class MistralClientAgentTests
|
||||
model: "mistral-small-latest",
|
||||
randomSeed: 0)
|
||||
.RegisterMessageConnector()
|
||||
.RegisterMiddleware(functionCallMiddlewareExecutorMiddleware);
|
||||
.RegisterStreamingMiddleware(functionCallMiddlewareExecutorMiddleware);
|
||||
await twoAgentTest.TwoAgentGetWeatherFunctionCallTestAsync(executorAgent, functionCallAgent);
|
||||
}
|
||||
|
||||
@@ -148,7 +148,7 @@ public partial class MistralClientAgentTests
|
||||
model: "mistral-small-latest",
|
||||
randomSeed: 0)
|
||||
.RegisterMessageConnector()
|
||||
.RegisterMiddleware(functionCallMiddleware);
|
||||
.RegisterStreamingMiddleware(functionCallMiddleware);
|
||||
|
||||
var question = new TextMessage(Role.User, "what's the weather in Seattle?");
|
||||
var reply = await functionCallAgent.SendAsync(question);
|
||||
@@ -193,7 +193,7 @@ public partial class MistralClientAgentTests
|
||||
toolChoice: ToolChoiceEnum.Any,
|
||||
randomSeed: 0)
|
||||
.RegisterMessageConnector()
|
||||
.RegisterMiddleware(functionCallMiddleware);
|
||||
.RegisterStreamingMiddleware(functionCallMiddleware);
|
||||
await singleAgentTest.EchoFunctionCallExecutionTestAsync(agent);
|
||||
await singleAgentTest.EchoFunctionCallExecutionStreamingTestAsync(agent);
|
||||
}
|
||||
@@ -214,7 +214,7 @@ public partial class MistralClientAgentTests
|
||||
systemMessage: "You are a helpful assistant that can call functions",
|
||||
randomSeed: 0)
|
||||
.RegisterMessageConnector()
|
||||
.RegisterMiddleware(functionCallMiddleware);
|
||||
.RegisterStreamingMiddleware(functionCallMiddleware);
|
||||
await singleAgentTest.EchoFunctionCallTestAsync(agent);
|
||||
|
||||
|
||||
@@ -222,7 +222,7 @@ public partial class MistralClientAgentTests
|
||||
var question = new TextMessage(Role.User, "what's the weather in Seattle?");
|
||||
IMessage? finalReply = null;
|
||||
|
||||
await foreach (var reply in await agent.GenerateStreamingReplyAsync([question]))
|
||||
await foreach (var reply in agent.GenerateStreamingReplyAsync([question]))
|
||||
{
|
||||
reply.From.Should().Be(agent.Name);
|
||||
if (reply is IMessage message)
|
||||
|
||||
@@ -47,7 +47,7 @@ public partial class OpenAIChatAgentTest
|
||||
reply.As<MessageEnvelope<ChatCompletions>>().Content.Usage.TotalTokens.Should().BeGreaterThan(0);
|
||||
|
||||
// test streaming
|
||||
var streamingReply = await openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent });
|
||||
var streamingReply = openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent });
|
||||
|
||||
await foreach (var streamingMessage in streamingReply)
|
||||
{
|
||||
@@ -93,7 +93,7 @@ public partial class OpenAIChatAgentTest
|
||||
// test streaming
|
||||
foreach (var message in messages)
|
||||
{
|
||||
var reply = await assistant.GenerateStreamingReplyAsync([message]);
|
||||
var reply = assistant.GenerateStreamingReplyAsync([message]);
|
||||
|
||||
await foreach (var streamingMessage in reply)
|
||||
{
|
||||
@@ -119,10 +119,9 @@ public partial class OpenAIChatAgentTest
|
||||
MiddlewareStreamingAgent<OpenAIChatAgent> assistant = openAIChatAgent
|
||||
.RegisterMessageConnector();
|
||||
|
||||
assistant.Middlewares.Count().Should().Be(1);
|
||||
assistant.StreamingMiddlewares.Count().Should().Be(1);
|
||||
var functionCallAgent = assistant
|
||||
.RegisterMiddleware(functionCallMiddleware);
|
||||
.RegisterStreamingMiddleware(functionCallMiddleware);
|
||||
|
||||
var question = "What's the weather in Seattle";
|
||||
var messages = new IMessage[]
|
||||
@@ -150,7 +149,7 @@ public partial class OpenAIChatAgentTest
|
||||
// test streaming
|
||||
foreach (var message in messages)
|
||||
{
|
||||
var reply = await functionCallAgent.GenerateStreamingReplyAsync([message]);
|
||||
var reply = functionCallAgent.GenerateStreamingReplyAsync([message]);
|
||||
ToolCallMessage? toolCallMessage = null;
|
||||
await foreach (var streamingMessage in reply)
|
||||
{
|
||||
@@ -191,7 +190,7 @@ public partial class OpenAIChatAgentTest
|
||||
.RegisterMessageConnector();
|
||||
|
||||
var functionCallAgent = assistant
|
||||
.RegisterMiddleware(functionCallMiddleware);
|
||||
.RegisterStreamingMiddleware(functionCallMiddleware);
|
||||
|
||||
var question = "What's the weather in Seattle";
|
||||
var messages = new IMessage[]
|
||||
@@ -220,7 +219,7 @@ public partial class OpenAIChatAgentTest
|
||||
// test streaming
|
||||
foreach (var message in messages)
|
||||
{
|
||||
var reply = await functionCallAgent.GenerateStreamingReplyAsync([message]);
|
||||
var reply = functionCallAgent.GenerateStreamingReplyAsync([message]);
|
||||
await foreach (var streamingMessage in reply)
|
||||
{
|
||||
if (streamingMessage is not IMessage)
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// RegisterReplyAgentTest.cs
|
||||
|
||||
using System.Threading.Tasks;
|
||||
using FluentAssertions;
|
||||
using Xunit;
|
||||
|
||||
namespace AutoGen.Tests
|
||||
{
|
||||
public class RegisterReplyAgentTest
|
||||
{
|
||||
[Fact]
|
||||
public async Task RegisterReplyTestAsync()
|
||||
{
|
||||
IAgent echoAgent = new EchoAgent("echo");
|
||||
echoAgent = echoAgent
|
||||
.RegisterReply(async (conversations, ct) => new TextMessage(Role.Assistant, "I'm your father", from: echoAgent.Name));
|
||||
|
||||
var msg = new Message(Role.User, "hey");
|
||||
var reply = await echoAgent.SendAsync(msg);
|
||||
reply.Should().BeOfType<TextMessage>();
|
||||
reply.GetContent().Should().Be("I'm your father");
|
||||
reply.GetRole().Should().Be(Role.Assistant);
|
||||
reply.From.Should().Be("echo");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -44,7 +44,7 @@ public partial class SemanticKernelAgentTest
|
||||
reply.As<MessageEnvelope<ChatMessageContent>>().From.Should().Be("assistant");
|
||||
|
||||
// test streaming
|
||||
var streamingReply = await skAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent });
|
||||
var streamingReply = skAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent });
|
||||
|
||||
await foreach (var streamingMessage in streamingReply)
|
||||
{
|
||||
@@ -63,10 +63,8 @@ public partial class SemanticKernelAgentTest
|
||||
|
||||
var kernel = builder.Build();
|
||||
|
||||
var connector = new SemanticKernelChatMessageContentConnector();
|
||||
var skAgent = new SemanticKernelAgent(kernel, "assistant")
|
||||
.RegisterStreamingMiddleware(connector)
|
||||
.RegisterMiddleware(connector);
|
||||
.RegisterMessageConnector();
|
||||
|
||||
var messages = new IMessage[]
|
||||
{
|
||||
@@ -90,7 +88,7 @@ public partial class SemanticKernelAgentTest
|
||||
// test streaming
|
||||
foreach (var message in messages)
|
||||
{
|
||||
var reply = await skAgent.GenerateStreamingReplyAsync([message]);
|
||||
var reply = skAgent.GenerateStreamingReplyAsync([message]);
|
||||
|
||||
await foreach (var streamingMessage in reply)
|
||||
{
|
||||
@@ -122,7 +120,6 @@ public partial class SemanticKernelAgentTest
|
||||
var skAgent = new SemanticKernelAgent(kernel, "assistant")
|
||||
.RegisterMessageConnector();
|
||||
|
||||
skAgent.Middlewares.Count().Should().Be(1);
|
||||
skAgent.StreamingMiddlewares.Count().Should().Be(1);
|
||||
|
||||
var question = "What is the weather in Seattle?";
|
||||
|
||||
@@ -261,7 +261,7 @@ namespace AutoGen.Tests
|
||||
{
|
||||
Temperature = 0,
|
||||
};
|
||||
var replyStream = await agent.GenerateStreamingReplyAsync(messages: new[] { message, helloWorld }, option);
|
||||
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message, helloWorld }, option);
|
||||
var answer = "[ECHO] Hello world";
|
||||
IStreamingMessage? finalReply = default;
|
||||
await foreach (var reply in replyStream)
|
||||
@@ -302,7 +302,7 @@ namespace AutoGen.Tests
|
||||
{
|
||||
Temperature = 0,
|
||||
};
|
||||
var replyStream = await agent.GenerateStreamingReplyAsync(messages: new[] { message, helloWorld }, option);
|
||||
var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message, helloWorld }, option);
|
||||
var answer = "A B C D E F G H I J K L M N";
|
||||
TextMessage? finalReply = default;
|
||||
await foreach (var reply in replyStream)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
##### Update
|
||||
- [API Breaking Change] Update the return type of `IStreamingAgent.GenerateStreamingReplyAsync` from `Task<IAsyncEnumerable<IStreamingMessage>>` to `IAsyncEnumerable<IStreamingMessage>`
|
||||
- [API Breaking Change] Update the return type of `IStreamingMiddleware.InvokeAsync` from `Task<IAsyncEnumerable<IStreamingMessage>>` to `IAsyncEnumerable<IStreamingMessage>`
|
||||
- [API Breaking Change] Mark `RegisterReply`, `RegisterPreProcess` and `RegisterPostProcess` as obsolete. You can replace them with `RegisterMiddleware`
|
||||
##### Update on 0.0.12 (2024-04-22)
|
||||
- Add AutoGen.Mistral package to support Mistral.AI models
|
||||
|
||||
##### Update on 0.0.11 (2024-04-10)
|
||||
- Add link to Discord channel in nuget's readme.md
|
||||
- Document improvements
|
||||
|
||||
Reference in New Issue
Block a user