Bring Dotnet AutoGen (#924)

* update readme

* update

* update

* update

* update

* update

* update

* add sample project

* revert notebook change back

* update

* update interactive version

* add nuget package

* refactor Message

* update example

* add azure nightly build pipeline

* Set up CI with Azure Pipelines

[skip ci]

* Update nightly-build.yml for Azure Pipelines

* add dotnet interactive package

* add dotnet interactive package

* update pipeline

* add nuget feed back

* remove dotnet-tool feed

* remove dotnet-tool feed comment

* update pipeline

* update build name

* Update nightly-build.yml

* Delete .github/workflows/dotnet-ci.yml

* update

* add working_dir to use step

* add initateChat api

* update oai package

* Update dotnet-build.yml

* Update dotnet-run-openai-test-and-notebooks.yml

* update build workflow

* update build workflow

* update nuget feed

* update nuget feed

* update aoai and sk version

* Update InteractiveService.cs

* add support for GPT 4V

* add DalleAndGPT4V example

* update example

* add user proxy agent

* add readme

* bump version

* update example

* add dotnet interactive hook

* update

* udpate tests

* add website

* update index.md

* add docs

* update doc

* move sk dependency out of core package

* udpate doc

* Update Use-function-call.md

* add type safe function call document

* update doc

* update doc

* add dock

* Update Use-function-call.md

* add GenerateReplyOptions

* remove IChatLLM

* update version

* update doc

* update website

* add sample

* fix link

* add middleware agent

* clean up doc

* bump version

* update doc

* update

* add Other Language

* remove warnings

* add sign.props

* add sign step

* fix pipelien

* auth

* real sign

* disable PR trigger

* update

* disable PR trigger

* use microbuild machine

* update build pipeline to add publish to internal feed

* add internal feed

* fix build pipeline

* add dotnet prefix

* update ci

* add build number

* update run number

* update source

* update token

* update

* remove adding source

* add publish to github package

* try again

* try again

* ask for write pacakge

* disable package when branch is not main

* update

* implement streaming agent

* add test for streaming function call

* update

* fix #1588

* enable PR check for dotnet branch

* add website readme

* only publish to dotnet feed when pushing to dotnet branch

* remove openai-test-and-notebooks workflow

* update readme

* update readme

* update workflow

* update getting-start

* upgrade test and sample proejct to use .net 8

* fix global.json format && make loadFromConfig API internal only before implementing

* update

* add support for LM studio

* add doc

* Update README.md

* add push and workflow_dispatch trigger

* disable PR for main

* add dotnet env

* Update Installation.md

* add nuget

* refer to newtonsoft 13

* update branch to dotnet in docfx

* Update Installation.md

* pull out HumanInputMiddleware and FunctionCallMiddleware

* fix tests

* add link to sample folder

* refactor message

* refactor over IMessage

* add more tests

* add more test

* fix build error

* rename header

* add semantic kernel project

* update sk example

* update dotnet version

* add LMStudio function call example

* rename LLaMAFunctin

* remove dotnet run openai test and notebook workflow

* add FunctionContract and test

* update doc

* add documents

* add workflow

* update

* update sample

* fix warning in test

* reult length can be less then maximumOutputToKeep (#1804)

* merge with main

* add option to retrieve inner agent and middlewares from MiddlewareAgent

* update doc

* adjust namespace

* update readme

* fix test

* use IMessage

* more updates

* update

* fix test

* add comments

* use FunctionContract to replace FunctionDefinition

* move AutoGen contrac to AutoGen.Core

* update installation

* refactor streamingAgent by adding StreamingMessage type

* update sample

* update samples

* update

* update

* add test

* fix test

* bump version

* add openaichat test

* update

* Update Example03_Agent_FunctionCall.cs

* [.Net] improve docs (#1862)

* add doc

* add doc

* add doc

* add doc

* add doc

* add doc

* update

* fix test error

* fix some error

* fix test

* fix test

* add more tests

* edits

---------

Co-authored-by: ekzhu <ekzhu@users.noreply.github.com>

* [.Net] Add fill form example (#1911)

* add form filler example

* update

* fix ci error

* [.Net] Add using AutoGen.Core in source generator (#1983)

* fix using namespace bug in source generator

* remove using in sourcegenerator test

* disable PR test

* Add .idea to .gitignore (#1988)

* [.Net] publish to nuget.org feed (#1987)

* publish to nuget

* update ci

* update dotnet-release

* update release pipeline

* add source

* remove empty symbol package

* update pipeline

* remove tag

* update installation guide

* [.Net] Rename some classes && APIs based on doc review (#1980)

* rename sequential group chat to round robin group chat

* rename to sendInstruction

* rename workflow to graph

* rename some api

* bump version

* move Graph to GroupChat folder

* rename fill application example

* [.Net] Improve package description (#2161)

* add discord link and update package description

* Update getting-start.md

* [.Net] Fix document comment from the most recent AutoGen.Net engineer sync (#2231)

* update

* rename RegisterPrintMessageHook to RegisterPrintMessage

* update website

* update update.md

* fix link error

* [.Net] Enable JsonMode and deterministic output in AutoGen.OpenAI OpenAIChatAgent (#2347)

* update openai version && add sample for json output

* add example in web

* update update.md

* update image url

* [.Net] Add AutoGen.Mistral package (#2330)

* add mstral client

* enable streaming support

* add mistralClientAgent

* add test for function call

* add extension

* add support for toolcall and toolcall result message

* add support for aggregate message

* implement streaming function call

* track (#2471)

* [.Net] add mistral example (#2482)

* update existing examples to use messageCOnnector

* add overview

* add function call document

* add example 14

* add mistral token count usage example

* update version

* Update dotnet-release.yml (#2488)

* update

* revert gitattributes

---------

Co-authored-by: mhensen <mh@webvize.nl>
Co-authored-by: ekzhu <ekzhu@users.noreply.github.com>
Co-authored-by: Krzysztof Kasprowicz <60486987+Krzysztof318@users.noreply.github.com>
This commit is contained in:
Xiaoyun Zhang
2024-04-26 09:21:46 -07:00
committed by GitHub
parent fbcc56c90e
commit 600bd3f2fe
226 changed files with 16125 additions and 22 deletions

View File

@@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// DefaultReplyAgent.cs
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
public class DefaultReplyAgent : IAgent
{
public DefaultReplyAgent(
string name,
string? defaultReply)
{
Name = name;
DefaultReply = defaultReply ?? string.Empty;
}
public string Name { get; }
public string DefaultReply { get; } = string.Empty;
public async Task<IMessage> GenerateReplyAsync(
IEnumerable<IMessage> _,
GenerateReplyOptions? __ = null,
CancellationToken ___ = default)
{
return new TextMessage(Role.Assistant, DefaultReply, from: this.Name);
}
}

View File

@@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GroupChatManager.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
public class GroupChatManager : IAgent
{
public GroupChatManager(IGroupChat groupChat)
{
GroupChat = groupChat;
}
public string Name => throw new ArgumentException("GroupChatManager does not have a name");
public IEnumerable<IMessage>? Messages { get; private set; }
public IGroupChat GroupChat { get; }
public async Task<IMessage> GenerateReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options,
CancellationToken cancellationToken = default)
{
var response = await GroupChat.CallAsync(messages, ct: cancellationToken);
Messages = response;
return response.Last();
}
}

View File

@@ -0,0 +1,50 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IAgent.cs
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
public interface IAgent
{
public string Name { get; }
/// <summary>
/// Generate reply
/// </summary>
/// <param name="messages">conversation history</param>
/// <param name="options">completion option. If provided, it should override existing option if there's any</param>
public Task<IMessage> GenerateReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default);
}
public class GenerateReplyOptions
{
public GenerateReplyOptions()
{
}
/// <summary>
/// Copy constructor
/// </summary>
/// <param name="other">other option to copy from</param>
public GenerateReplyOptions(GenerateReplyOptions other)
{
this.Temperature = other.Temperature;
this.MaxToken = other.MaxToken;
this.StopSequence = other.StopSequence?.Select(s => s)?.ToArray();
this.Functions = other.Functions?.Select(f => f)?.ToArray();
}
public float? Temperature { get; set; }
public int? MaxToken { get; set; }
public string[]? StopSequence { get; set; }
public FunctionContract[]? Functions { get; set; }
}

View File

@@ -0,0 +1,50 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IMiddlewareAgent.cs
using System.Collections.Generic;
namespace AutoGen.Core;
public interface IMiddlewareAgent : IAgent
{
/// <summary>
/// Get the inner agent.
/// </summary>
IAgent Agent { get; }
/// <summary>
/// Get the middlewares.
/// </summary>
IEnumerable<IMiddleware> Middlewares { get; }
/// <summary>
/// Use middleware.
/// </summary>
void Use(IMiddleware middleware);
}
public interface IMiddlewareStreamAgent : IMiddlewareAgent, IStreamingAgent
{
/// <summary>
/// Get the inner agent.
/// </summary>
IStreamingAgent StreamingAgent { get; }
IEnumerable<IStreamingMiddleware> StreamingMiddlewares { get; }
void UseStreaming(IStreamingMiddleware middleware);
}
public interface IMiddlewareAgent<out T> : IMiddlewareAgent
where T : IAgent
{
/// <summary>
/// Get the typed inner agent.
/// </summary>
T TAgent { get; }
}
public interface IMiddlewareStreamAgent<out T> : IMiddlewareStreamAgent, IMiddlewareAgent<T>
where T : IStreamingAgent
{
}

View File

@@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IStreamingAgent.cs
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
/// <summary>
/// agent that supports streaming reply
/// </summary>
public interface IStreamingAgent : IAgent
{
public Task<IAsyncEnumerable<IStreamingMessage>> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default);
}

View File

@@ -0,0 +1,136 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// MiddlewareAgent.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
/// <summary>
/// An agent that allows you to add middleware and modify the behavior of an existing agent.
/// </summary>
public class MiddlewareAgent : IMiddlewareAgent
{
private readonly IAgent _agent;
private readonly List<IMiddleware> middlewares = new();
/// <summary>
/// Create a new instance of <see cref="MiddlewareAgent"/>
/// </summary>
/// <param name="innerAgent">the inner agent where middleware will be added.</param>
/// <param name="name">the name of the agent if provided. Otherwise, the name of <paramref name="innerAgent"/> will be used.</param>
public MiddlewareAgent(IAgent innerAgent, string? name = null)
{
this.Name = name ?? innerAgent.Name;
this._agent = innerAgent;
}
/// <summary>
/// Create a new instance of <see cref="MiddlewareAgent"/> by copying the middlewares from another <see cref="MiddlewareAgent"/>.
/// </summary>
public MiddlewareAgent(MiddlewareAgent other)
{
this.Name = other.Name;
this._agent = other._agent;
this.middlewares.AddRange(other.middlewares);
}
public string Name { get; }
/// <summary>
/// Get the inner agent.
/// </summary>
public IAgent Agent => this._agent;
/// <summary>
/// Get the middlewares.
/// </summary>
public IEnumerable<IMiddleware> Middlewares => this.middlewares;
public Task<IMessage> GenerateReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
IAgent agent = this._agent;
foreach (var middleware in this.middlewares)
{
agent = new DelegateAgent(middleware, agent);
}
return agent.GenerateReplyAsync(messages, options, cancellationToken);
}
/// <summary>
/// Add a middleware to the agent. If multiple middlewares are added, they will be executed in the LIFO order.
/// Call into the next function to continue the execution of the next middleware.
/// Short cut middleware execution by not calling into the next function.
/// </summary>
public void Use(Func<IEnumerable<IMessage>, GenerateReplyOptions?, IAgent, CancellationToken, Task<IMessage>> func, string? middlewareName = null)
{
this.middlewares.Add(new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) =>
{
return await func(context.Messages, context.Options, agent, cancellationToken);
}));
}
public void Use(IMiddleware middleware)
{
this.middlewares.Add(middleware);
}
public override string ToString()
{
var names = this.Middlewares.Select(m => m.Name ?? "[Unknown middleware]");
var namesPlusAgentName = names.Append(this.Name);
return namesPlusAgentName.Aggregate((a, b) => $"{a} -> {b}");
}
private class DelegateAgent : IAgent
{
private readonly IAgent innerAgent;
private readonly IMiddleware middleware;
public DelegateAgent(IMiddleware middleware, IAgent innerAgent)
{
this.middleware = middleware;
this.innerAgent = innerAgent;
}
public string Name { get => this.innerAgent.Name; }
public Task<IMessage> GenerateReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
var context = new MiddlewareContext(messages, options);
return this.middleware.InvokeAsync(context, this.innerAgent, cancellationToken);
}
}
}
public sealed class MiddlewareAgent<T> : MiddlewareAgent, IMiddlewareAgent<T>
where T : IAgent
{
public MiddlewareAgent(T innerAgent, string? name = null)
: base(innerAgent, name)
{
this.TAgent = innerAgent;
}
public MiddlewareAgent(MiddlewareAgent<T> other)
: base(other)
{
this.TAgent = other.TAgent;
}
/// <summary>
/// Get the inner agent of type <typeparamref name="T"/>.
/// </summary>
public T TAgent { get; }
}

View File

@@ -0,0 +1,124 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// MiddlewareStreamingAgent.cs
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
public class MiddlewareStreamingAgent : MiddlewareAgent, IMiddlewareStreamAgent
{
private readonly IStreamingAgent _agent;
private readonly List<IStreamingMiddleware> _streamingMiddlewares = new();
private readonly List<IMiddleware> _middlewares = new();
public MiddlewareStreamingAgent(
IStreamingAgent agent,
string? name = null,
IEnumerable<IStreamingMiddleware>? streamingMiddlewares = null,
IEnumerable<IMiddleware>? middlewares = null)
: base(agent, name)
{
_agent = agent;
if (streamingMiddlewares != null)
{
_streamingMiddlewares.AddRange(streamingMiddlewares);
}
if (middlewares != null)
{
_middlewares.AddRange(middlewares);
}
}
/// <summary>
/// Get the inner agent.
/// </summary>
public IStreamingAgent StreamingAgent => _agent;
/// <summary>
/// Get the streaming middlewares.
/// </summary>
public IEnumerable<IStreamingMiddleware> StreamingMiddlewares => _streamingMiddlewares;
public Task<IAsyncEnumerable<IStreamingMessage>> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
var agent = _agent;
foreach (var middleware in _streamingMiddlewares)
{
agent = new DelegateStreamingAgent(middleware, agent);
}
return agent.GenerateStreamingReplyAsync(messages, options, cancellationToken);
}
public void UseStreaming(IStreamingMiddleware middleware)
{
_streamingMiddlewares.Add(middleware);
}
private class DelegateStreamingAgent : IStreamingAgent
{
private IStreamingMiddleware? streamingMiddleware;
private IMiddleware? middleware;
private IStreamingAgent innerAgent;
public string Name => innerAgent.Name;
public DelegateStreamingAgent(IStreamingMiddleware middleware, IStreamingAgent next)
{
this.streamingMiddleware = middleware;
this.innerAgent = next;
}
public DelegateStreamingAgent(IMiddleware middleware, IStreamingAgent next)
{
this.middleware = middleware;
this.innerAgent = next;
}
public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
if (middleware is null)
{
return await innerAgent.GenerateReplyAsync(messages, options, cancellationToken);
}
var context = new MiddlewareContext(messages, options);
return await middleware.InvokeAsync(context, innerAgent, cancellationToken);
}
public Task<IAsyncEnumerable<IStreamingMessage>> GenerateStreamingReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
if (streamingMiddleware is null)
{
return innerAgent.GenerateStreamingReplyAsync(messages, options, cancellationToken);
}
var context = new MiddlewareContext(messages, options);
return streamingMiddleware.InvokeAsync(context, innerAgent, cancellationToken);
}
}
}
public sealed class MiddlewareStreamingAgent<T> : MiddlewareStreamingAgent, IMiddlewareStreamAgent<T>
where T : IStreamingAgent
{
public MiddlewareStreamingAgent(T innerAgent, string? name = null)
: base(innerAgent, name)
{
TAgent = innerAgent;
}
public MiddlewareStreamingAgent(MiddlewareStreamingAgent<T> other)
: base(other)
{
TAgent = other.TAgent;
}
/// <summary>
/// Get the inner agent.
/// </summary>
public T TAgent { get; }
}

View File

@@ -0,0 +1,21 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<RootNamespace>AutoGen.Core</RootNamespace>
</PropertyGroup>
<Import Project="$(RepoRoot)/dotnet/nuget/nuget-package.props" />
<PropertyGroup>
<!-- NuGet Package Settings -->
<Title>AutoGen.Core</Title>
<Description>
Core library for AutoGen. This package provides contracts and core functionalities for AutoGen.
</Description>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="JsonSchema.Net.Generation" Version="$(JsonSchemaVersion)" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,174 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentExtension.cs
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
public static class AgentExtension
{
/// <summary>
/// Send message to an agent.
/// </summary>
/// <param name="message">message to send. will be added to the end of <paramref name="chatHistory"/> if provided </param>
/// <param name="agent">sender agent.</param>
/// <param name="chatHistory">chat history.</param>
/// <returns>conversation history</returns>
public static async Task<IMessage> SendAsync(
this IAgent agent,
IMessage? message = null,
IEnumerable<IMessage>? chatHistory = null,
CancellationToken ct = default)
{
var messages = new List<IMessage>();
if (chatHistory != null)
{
messages.AddRange(chatHistory);
}
if (message != null)
{
messages.Add(message);
}
var result = await agent.GenerateReplyAsync(messages, cancellationToken: ct);
return result;
}
/// <summary>
/// Send message to an agent.
/// </summary>
/// <param name="agent">sender agent.</param>
/// <param name="message">message to send. will be added to the end of <paramref name="chatHistory"/> if provided </param>
/// <param name="chatHistory">chat history.</param>
/// <returns>conversation history</returns>
public static async Task<IMessage> SendAsync(
this IAgent agent,
string message,
IEnumerable<IMessage>? chatHistory = null,
CancellationToken ct = default)
{
var msg = new TextMessage(Role.User, message);
return await agent.SendAsync(msg, chatHistory, ct);
}
/// <summary>
/// Send message to another agent.
/// </summary>
/// <param name="agent">sender agent.</param>
/// <param name="receiver">receiver agent.</param>
/// <param name="chatHistory">chat history.</param>
/// <param name="maxRound">max conversation round.</param>
/// <returns>conversation history</returns>
public static async Task<IEnumerable<IMessage>> SendAsync(
this IAgent agent,
IAgent receiver,
IEnumerable<IMessage> chatHistory,
int maxRound = 10,
CancellationToken ct = default)
{
if (receiver is GroupChatManager manager)
{
var gc = manager.GroupChat;
return await agent.SendMessageToGroupAsync(gc, chatHistory, maxRound, ct);
}
var groupChat = new RoundRobinGroupChat(
agents: new[]
{
agent,
receiver,
});
return await groupChat.CallAsync(chatHistory, maxRound, ct: ct);
}
/// <summary>
/// Send message to another agent.
/// </summary>
/// <param name="agent">sender agent.</param>
/// <param name="message">message to send. will be added to the end of <paramref name="chatHistory"/> if provided </param>
/// <param name="receiver">receiver agent.</param>
/// <param name="chatHistory">chat history.</param>
/// <param name="maxRound">max conversation round.</param>
/// <returns>conversation history</returns>
public static async Task<IEnumerable<IMessage>> SendAsync(
this IAgent agent,
IAgent receiver,
string message,
IEnumerable<IMessage>? chatHistory = null,
int maxRound = 10,
CancellationToken ct = default)
{
var msg = new TextMessage(Role.User, message)
{
From = agent.Name,
};
chatHistory = chatHistory ?? new List<IMessage>();
chatHistory = chatHistory.Append(msg);
return await agent.SendAsync(receiver, chatHistory, maxRound, ct);
}
/// <summary>
/// Shortcut API to send message to another agent.
/// </summary>
/// <param name="agent">sender agent</param>
/// <param name="receiver">receiver agent</param>
/// <param name="message">message to send</param>
/// <param name="maxRound">max round</param>
public static async Task<IEnumerable<IMessage>> InitiateChatAsync(
this IAgent agent,
IAgent receiver,
string? message = null,
int maxRound = 10,
CancellationToken ct = default)
{
var chatHistory = new List<IMessage>();
if (message != null)
{
var msg = new TextMessage(Role.User, message)
{
From = agent.Name,
};
chatHistory.Add(msg);
}
return await agent.SendAsync(receiver, chatHistory, maxRound, ct);
}
public static async Task<IEnumerable<IMessage>> SendMessageToGroupAsync(
this IAgent agent,
IGroupChat groupChat,
string msg,
IEnumerable<IMessage>? chatHistory = null,
int maxRound = 10,
CancellationToken ct = default)
{
var chatMessage = new TextMessage(Role.Assistant, msg, from: agent.Name);
chatHistory = chatHistory ?? Enumerable.Empty<IMessage>();
chatHistory = chatHistory.Append(chatMessage);
return await agent.SendMessageToGroupAsync(groupChat, chatHistory, maxRound, ct);
}
public static async Task<IEnumerable<IMessage>> SendMessageToGroupAsync(
this IAgent _,
IGroupChat groupChat,
IEnumerable<IMessage>? chatHistory = null,
int maxRound = 10,
CancellationToken ct = default)
{
return await groupChat.CallAsync(chatHistory, maxRound, ct);
}
}

View File

@@ -0,0 +1,109 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GroupChatExtension.cs
using System;
using System.Collections.Generic;
using System.Linq;
namespace AutoGen.Core;
public static class GroupChatExtension
{
public const string TERMINATE = "[GROUPCHAT_TERMINATE]";
public const string CLEAR_MESSAGES = "[GROUPCHAT_CLEAR_MESSAGES]";
[Obsolete("please use SendIntroduction")]
public static void AddInitializeMessage(this IAgent agent, string message, IGroupChat groupChat)
{
var msg = new TextMessage(Role.User, message)
{
From = agent.Name
};
groupChat.SendIntroduction(msg);
}
/// <summary>
/// Send an instruction message to the group chat.
/// </summary>
public static void SendIntroduction(this IAgent agent, string message, IGroupChat groupChat)
{
var msg = new TextMessage(Role.User, message)
{
From = agent.Name
};
groupChat.SendIntroduction(msg);
}
public static IEnumerable<IMessage> MessageToKeep(
this IGroupChat _,
IEnumerable<IMessage> messages)
{
var lastCLRMessageIndex = messages.ToList()
.FindLastIndex(x => x.IsGroupChatClearMessage());
// if multiple clr messages, e.g [msg, clr, msg, clr, msg, clr, msg]
// only keep the the messages after the second last clr message.
if (messages.Count(m => m.IsGroupChatClearMessage()) > 1)
{
lastCLRMessageIndex = messages.ToList()
.FindLastIndex(lastCLRMessageIndex - 1, lastCLRMessageIndex - 1, x => x.IsGroupChatClearMessage());
messages = messages.Skip(lastCLRMessageIndex);
}
lastCLRMessageIndex = messages.ToList()
.FindLastIndex(x => x.IsGroupChatClearMessage());
if (lastCLRMessageIndex != -1 && messages.Count() - lastCLRMessageIndex >= 2)
{
messages = messages.Skip(lastCLRMessageIndex);
}
return messages;
}
/// <summary>
/// Return true if <see cref="IMessage"/> contains <see cref="TERMINATE"/>, otherwise false.
/// </summary>
/// <param name="message"></param>
/// <returns></returns>
public static bool IsGroupChatTerminateMessage(this IMessage message)
{
return message.GetContent()?.Contains(TERMINATE) ?? false;
}
public static bool IsGroupChatClearMessage(this IMessage message)
{
return message.GetContent()?.Contains(CLEAR_MESSAGES) ?? false;
}
public static IEnumerable<IMessage> ProcessConversationForAgent(
this IGroupChat groupChat,
IEnumerable<IMessage> initialMessages,
IEnumerable<IMessage> messages)
{
messages = groupChat.MessageToKeep(messages);
return initialMessages.Concat(messages);
}
internal static IEnumerable<IMessage> ProcessConversationsForRolePlay(
this IGroupChat groupChat,
IEnumerable<IMessage> initialMessages,
IEnumerable<IMessage> messages)
{
messages = groupChat.MessageToKeep(messages);
var messagesToKeep = initialMessages.Concat(messages);
return messagesToKeep.Select((x, i) =>
{
var msg = @$"From {x.From}:
{x.GetContent()}
<eof_msg>
round #
{i}";
return new TextMessage(Role.User, content: msg);
});
}
}

View File

@@ -0,0 +1,213 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// MessageExtension.cs
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace AutoGen.Core;
public static class MessageExtension
{
private static string separator = new string('-', 20);
public static string FormatMessage(this IMessage message)
{
return message switch
{
Message msg => msg.FormatMessage(),
TextMessage textMessage => textMessage.FormatMessage(),
ImageMessage imageMessage => imageMessage.FormatMessage(),
ToolCallMessage toolCallMessage => toolCallMessage.FormatMessage(),
ToolCallResultMessage toolCallResultMessage => toolCallResultMessage.FormatMessage(),
AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => aggregateMessage.FormatMessage(),
_ => message.ToString(),
};
}
public static string FormatMessage(this TextMessage message)
{
var sb = new StringBuilder();
// write from
sb.AppendLine($"TextMessage from {message.From}");
// write a seperator
sb.AppendLine(separator);
sb.AppendLine(message.Content);
// write a seperator
sb.AppendLine(separator);
return sb.ToString();
}
public static string FormatMessage(this ImageMessage message)
{
var sb = new StringBuilder();
// write from
sb.AppendLine($"ImageMessage from {message.From}");
// write a seperator
sb.AppendLine(separator);
sb.AppendLine($"Image: {message.Url}");
// write a seperator
sb.AppendLine(separator);
return sb.ToString();
}
public static string FormatMessage(this ToolCallMessage message)
{
var sb = new StringBuilder();
// write from
sb.AppendLine($"ToolCallMessage from {message.From}");
// write a seperator
sb.AppendLine(separator);
foreach (var toolCall in message.ToolCalls)
{
sb.AppendLine($"- {toolCall.FunctionName}: {toolCall.FunctionArguments}");
}
sb.AppendLine(separator);
return sb.ToString();
}
public static string FormatMessage(this ToolCallResultMessage message)
{
var sb = new StringBuilder();
// write from
sb.AppendLine($"ToolCallResultMessage from {message.From}");
// write a seperator
sb.AppendLine(separator);
foreach (var toolCall in message.ToolCalls)
{
sb.AppendLine($"- {toolCall.FunctionName}: {toolCall.Result}");
}
sb.AppendLine(separator);
return sb.ToString();
}
public static string FormatMessage(this AggregateMessage<ToolCallMessage, ToolCallResultMessage> message)
{
var sb = new StringBuilder();
// write from
sb.AppendLine($"AggregateMessage from {message.From}");
// write a seperator
sb.AppendLine(separator);
sb.AppendLine("ToolCallMessage:");
sb.AppendLine(message.Message1.FormatMessage());
sb.AppendLine("ToolCallResultMessage:");
sb.AppendLine(message.Message2.FormatMessage());
sb.AppendLine(separator);
return sb.ToString();
}
public static string FormatMessage(this Message message)
{
var sb = new StringBuilder();
// write from
sb.AppendLine($"Message from {message.From}");
// write a seperator
sb.AppendLine(separator);
// write content
sb.AppendLine($"content: {message.Content}");
// write function name if exists
if (!string.IsNullOrEmpty(message.FunctionName))
{
sb.AppendLine($"function name: {message.FunctionName}");
sb.AppendLine($"function arguments: {message.FunctionArguments}");
}
// write metadata
if (message.Metadata is { Count: > 0 })
{
sb.AppendLine($"metadata:");
foreach (var item in message.Metadata)
{
sb.AppendLine($"{item.Key}: {item.Value}");
}
}
// write a seperator
sb.AppendLine(separator);
return sb.ToString();
}
public static bool IsSystemMessage(this IMessage message)
{
return message switch
{
TextMessage textMessage => textMessage.Role == Role.System,
Message msg => msg.Role == Role.System,
_ => false,
};
}
/// <summary>
/// Get the content from the message
/// <para>if the message is a <see cref="Message"/> or <see cref="TextMessage"/>, return the content</para>
/// <para>if the message is a <see cref="ToolCallResultMessage"/> and only contains one function call, return the result of that function call</para>
/// <para>if the message is a <see cref="AggregateMessage{ToolCallMessage, ToolCallResultMessage}"/> where TMessage1 is <see cref="ToolCallMessage"/> and TMessage2 is <see cref="ToolCallResultMessage"/> and the second message only contains one function call, return the result of that function call</para>
/// <para>for all other situation, return null.</para>
/// </summary>
/// <param name="message"></param>
public static string? GetContent(this IMessage message)
{
return message switch
{
TextMessage textMessage => textMessage.Content,
Message msg => msg.Content,
ToolCallResultMessage toolCallResultMessage => toolCallResultMessage.ToolCalls.Count == 1 ? toolCallResultMessage.ToolCalls.First().Result : null,
AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => aggregateMessage.Message2.ToolCalls.Count == 1 ? aggregateMessage.Message2.ToolCalls.First().Result : null,
_ => null,
};
}
/// <summary>
/// Get the role from the message if it's available.
/// </summary>
public static Role? GetRole(this IMessage message)
{
return message switch
{
TextMessage textMessage => textMessage.Role,
Message msg => msg.Role,
ImageMessage img => img.Role,
MultiModalMessage multiModal => multiModal.Role,
_ => null,
};
}
/// <summary>
/// Return the tool calls from the message if it's available.
/// <para>if the message is a <see cref="ToolCallMessage"/>, return its tool calls</para>
/// <para>if the message is a <see cref="Message"/> and the function name and function arguments are available, return a list of tool call with one item</para>
/// <para>if the message is a <see cref="AggregateMessage{ToolCallMessage, ToolCallResultMessage}"/> where TMessage1 is <see cref="ToolCallMessage"/> and TMessage2 is <see cref="ToolCallResultMessage"/>, return the tool calls from the first message</para>
/// </summary>
/// <param name="message"></param>
/// <returns></returns>
public static IList<ToolCall>? GetToolCalls(this IMessage message)
{
return message switch
{
ToolCallMessage toolCallMessage => toolCallMessage.ToolCalls,
Message msg => msg.FunctionName is not null && msg.FunctionArguments is not null
? msg.Content is not null ? new List<ToolCall> { new ToolCall(msg.FunctionName, msg.FunctionArguments, result: msg.Content) }
: new List<ToolCall> { new ToolCall(msg.FunctionName, msg.FunctionArguments) }
: null,
AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => aggregateMessage.Message1.ToolCalls,
_ => null,
};
}
}

View File

@@ -0,0 +1,138 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// MiddlewareExtension.cs
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
public static class MiddlewareExtension
{
/// <summary>
/// Register a auto reply hook to an agent. The hook will be called before the agent generate the reply.
/// If the hook return a non-null reply, then that non-null reply will be returned directly without calling the agent.
/// Otherwise, the agent will generate the reply.
/// This is useful when you want to override the agent reply in some cases.
/// </summary>
/// <param name="agent"></param>
/// <param name="replyFunc"></param>
/// <returns></returns>
/// <exception cref="Exception">throw when agent name is null.</exception>
public static MiddlewareAgent<TAgent> RegisterReply<TAgent>(
this TAgent agent,
Func<IEnumerable<IMessage>, CancellationToken, Task<IMessage?>> replyFunc)
where TAgent : IAgent
{
return agent.RegisterMiddleware(async (messages, options, agent, ct) =>
{
var reply = await replyFunc(messages, ct);
if (reply != null)
{
return reply;
}
return await agent.GenerateReplyAsync(messages, options, ct);
});
}
/// <summary>
/// Register a post process hook to an agent. The hook will be called before the agent return the reply and after the agent generate the reply.
/// This is useful when you want to customize arbitrary behavior before the agent return the reply.
///
/// One example is <see cref="PrintMessageMiddlewareExtension.RegisterPrintMessage{TAgent}(TAgent)" />, which print the formatted message to console before the agent return the reply.
/// </summary>
/// <exception cref="Exception">throw when agent name is null.</exception>
public static MiddlewareAgent<TAgent> RegisterPostProcess<TAgent>(
this TAgent agent,
Func<IEnumerable<IMessage>, IMessage, CancellationToken, Task<IMessage>> postprocessFunc)
where TAgent : IAgent
{
return agent.RegisterMiddleware(async (messages, options, agent, ct) =>
{
var reply = await agent.GenerateReplyAsync(messages, options, ct);
return await postprocessFunc(messages, reply, ct);
});
}
/// <summary>
/// Register a pre process hook to an agent. The hook will be called before the agent generate the reply. This is useful when you want to modify the conversation history before the agent generate the reply.
/// </summary>
/// <exception cref="Exception">throw when agent name is null.</exception>
public static MiddlewareAgent<TAgent> RegisterPreProcess<TAgent>(
this TAgent agent,
Func<IEnumerable<IMessage>, CancellationToken, Task<IEnumerable<IMessage>>> preprocessFunc)
where TAgent : IAgent
{
return agent.RegisterMiddleware(async (messages, options, agent, ct) =>
{
var newMessages = await preprocessFunc(messages, ct);
return await agent.GenerateReplyAsync(newMessages, options, ct);
});
}
/// <summary>
/// Register a middleware to an existing agent and return a new agent with the middleware.
/// </summary>
public static MiddlewareAgent<TAgent> RegisterMiddleware<TAgent>(
this TAgent agent,
Func<IEnumerable<IMessage>, GenerateReplyOptions?, IAgent, CancellationToken, Task<IMessage>> func,
string? middlewareName = null)
where TAgent : IAgent
{
var middleware = new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) =>
{
return await func(context.Messages, context.Options, agent, cancellationToken);
});
return agent.RegisterMiddleware(middleware);
}
/// <summary>
/// Register a middleware to an existing agent and return a new agent with the middleware.
/// </summary>
public static MiddlewareAgent<TAgent> RegisterMiddleware<TAgent>(
this TAgent agent,
IMiddleware middleware)
where TAgent : IAgent
{
var middlewareAgent = new MiddlewareAgent<TAgent>(agent);
return middlewareAgent.RegisterMiddleware(middleware);
}
/// <summary>
/// Register a middleware to an existing agent and return a new agent with the middleware.
/// </summary>
public static MiddlewareAgent<TAgent> RegisterMiddleware<TAgent>(
this MiddlewareAgent<TAgent> agent,
Func<IEnumerable<IMessage>, GenerateReplyOptions?, IAgent, CancellationToken, Task<IMessage>> func,
string? middlewareName = null)
where TAgent : IAgent
{
var delegateMiddleware = new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) =>
{
return await func(context.Messages, context.Options, agent, cancellationToken);
});
return agent.RegisterMiddleware(delegateMiddleware);
}
/// <summary>
/// Register a middleware to an existing agent and return a new agent with the middleware.
/// </summary>
public static MiddlewareAgent<TAgent> RegisterMiddleware<TAgent>(
this MiddlewareAgent<TAgent> agent,
IMiddleware middleware)
where TAgent : IAgent
{
var copyAgent = new MiddlewareAgent<TAgent>(agent);
copyAgent.Use(middleware);
return copyAgent;
}
}

View File

@@ -0,0 +1,69 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// PrintMessageMiddlewareExtension.cs
using System;
namespace AutoGen.Core;
public static class PrintMessageMiddlewareExtension
{
[Obsolete("This API will be removed in v0.1.0, Use RegisterPrintMessage instead.")]
public static MiddlewareAgent<TAgent> RegisterPrintFormatMessageHook<TAgent>(this TAgent agent)
where TAgent : IAgent
{
return RegisterPrintMessage(agent);
}
[Obsolete("This API will be removed in v0.1.0, Use RegisterPrintMessage instead.")]
public static MiddlewareAgent<TAgent> RegisterPrintFormatMessageHook<TAgent>(this MiddlewareAgent<TAgent> agent)
where TAgent : IAgent
{
return RegisterPrintMessage(agent);
}
[Obsolete("This API will be removed in v0.1.0, Use RegisterPrintMessage instead.")]
public static MiddlewareStreamingAgent<TAgent> RegisterPrintFormatMessageHook<TAgent>(this MiddlewareStreamingAgent<TAgent> agent)
where TAgent : IStreamingAgent
{
return RegisterPrintMessage(agent);
}
/// <summary>
/// Register a <see cref="PrintMessageMiddleware"/> to <paramref name="agent"/> which print formatted message to console.
/// </summary>
public static MiddlewareAgent<TAgent> RegisterPrintMessage<TAgent>(this TAgent agent)
where TAgent : IAgent
{
var middleware = new PrintMessageMiddleware();
var middlewareAgent = new MiddlewareAgent<TAgent>(agent);
middlewareAgent.Use(middleware);
return middlewareAgent;
}
/// <summary>
/// Register a <see cref="PrintMessageMiddleware"/> to <paramref name="agent"/> which print formatted message to console.
/// </summary>
public static MiddlewareAgent<TAgent> RegisterPrintMessage<TAgent>(this MiddlewareAgent<TAgent> agent)
where TAgent : IAgent
{
var middleware = new PrintMessageMiddleware();
var middlewareAgent = new MiddlewareAgent<TAgent>(agent);
middlewareAgent.Use(middleware);
return middlewareAgent;
}
/// <summary>
/// Register a <see cref="PrintMessageMiddleware"/> to <paramref name="agent"/> which print formatted message to console.
/// </summary>
public static MiddlewareStreamingAgent<TAgent> RegisterPrintMessage<TAgent>(this MiddlewareStreamingAgent<TAgent> agent)
where TAgent : IStreamingAgent
{
var middleware = new PrintMessageMiddleware();
var middlewareAgent = new MiddlewareStreamingAgent<TAgent>(agent);
middlewareAgent.Use(middleware);
return middlewareAgent;
}
}

View File

@@ -0,0 +1,114 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// StreamingMiddlewareExtension.cs
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
public static class StreamingMiddlewareExtension
{
/// <summary>
/// Register a middleware to an existing agent and return a new agent with the middleware.
/// </summary>
public static MiddlewareStreamingAgent<TStreamingAgent> RegisterStreamingMiddleware<TStreamingAgent>(
this TStreamingAgent agent,
IStreamingMiddleware middleware)
where TStreamingAgent : IStreamingAgent
{
var middlewareAgent = new MiddlewareStreamingAgent<TStreamingAgent>(agent);
middlewareAgent.UseStreaming(middleware);
if (middleware is IMiddleware middlewareBase)
{
middlewareAgent.Use(middlewareBase);
}
return middlewareAgent;
}
/// <summary>
/// Register a middleware to an existing agent and return a new agent with the middleware.
/// </summary>
public static MiddlewareStreamingAgent<TAgent> RegisterStreamingMiddleware<TAgent>(
this MiddlewareStreamingAgent<TAgent> agent,
IStreamingMiddleware middleware)
where TAgent : IStreamingAgent
{
var copyAgent = new MiddlewareStreamingAgent<TAgent>(agent);
copyAgent.UseStreaming(middleware);
if (middleware is IMiddleware middlewareBase)
{
copyAgent.Use(middlewareBase);
}
return copyAgent;
}
/// <summary>
/// Register a middleware to an existing agent and return a new agent with the middleware.
/// </summary>
public static MiddlewareStreamingAgent<TAgent> RegisterStreamingMiddleware<TAgent>(
this TAgent agent,
Func<MiddlewareContext, IStreamingAgent, CancellationToken, Task<IAsyncEnumerable<IStreamingMessage>>> func,
string? middlewareName = null)
where TAgent : IStreamingAgent
{
var middleware = new DelegateStreamingMiddleware(middlewareName, new DelegateStreamingMiddleware.MiddlewareDelegate(func));
return agent.RegisterStreamingMiddleware(middleware);
}
/// <summary>
/// Register a streaming middleware to an existing agent and return a new agent with the middleware.
/// </summary>
public static MiddlewareStreamingAgent<TAgent> RegisterStreamingMiddleware<TAgent>(
this MiddlewareStreamingAgent<TAgent> agent,
Func<MiddlewareContext, IStreamingAgent, CancellationToken, Task<IAsyncEnumerable<IStreamingMessage>>> func,
string? middlewareName = null)
where TAgent : IStreamingAgent
{
var middleware = new DelegateStreamingMiddleware(middlewareName, new DelegateStreamingMiddleware.MiddlewareDelegate(func));
return agent.RegisterStreamingMiddleware(middleware);
}
/// <summary>
/// Register a middleware to an existing streaming agent and return a new agent with the middleware.
/// </summary>
public static MiddlewareStreamingAgent<TStreamingAgent> RegisterMiddleware<TStreamingAgent>(
this MiddlewareStreamingAgent<TStreamingAgent> streamingAgent,
Func<IEnumerable<IMessage>, GenerateReplyOptions?, IAgent, CancellationToken, Task<IMessage>> func,
string? middlewareName = null)
where TStreamingAgent : IStreamingAgent
{
var middleware = new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) =>
{
return await func(context.Messages, context.Options, agent, cancellationToken);
});
return streamingAgent.RegisterMiddleware(middleware);
}
/// <summary>
/// Register a middleware to an existing streaming agent and return a new agent with the middleware.
/// </summary>
public static MiddlewareStreamingAgent<TStreamingAgent> RegisterMiddleware<TStreamingAgent>(
this MiddlewareStreamingAgent<TStreamingAgent> streamingAgent,
IMiddleware middleware)
where TStreamingAgent : IStreamingAgent
{
var copyAgent = new MiddlewareStreamingAgent<TStreamingAgent>(streamingAgent);
copyAgent.Use(middleware);
if (middleware is IStreamingMiddleware streamingMiddleware)
{
copyAgent.UseStreaming(streamingMiddleware);
}
return copyAgent;
}
}

View File

@@ -0,0 +1,93 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// FunctionAttribute.cs
using System;
using System.Collections.Generic;
namespace AutoGen.Core;
[AttributeUsage(AttributeTargets.Method, Inherited = false, AllowMultiple = false)]
public class FunctionAttribute : Attribute
{
public string? FunctionName { get; }
public string? Description { get; }
public FunctionAttribute(string? functionName = null, string? description = null)
{
FunctionName = functionName;
Description = description;
}
}
public class FunctionContract
{
/// <summary>
/// The namespace of the function.
/// </summary>
public string? Namespace { get; set; }
/// <summary>
/// The class name of the function.
/// </summary>
public string? ClassName { get; set; }
/// <summary>
/// The name of the function.
/// </summary>
public string? Name { get; set; }
/// <summary>
/// The description of the function.
/// If a structured comment is available, the description will be extracted from the summary section.
/// Otherwise, the description will be null.
/// </summary>
public string? Description { get; set; }
/// <summary>
/// The parameters of the function.
/// </summary>
public IEnumerable<FunctionParameterContract>? Parameters { get; set; }
/// <summary>
/// The return type of the function.
/// </summary>
public Type? ReturnType { get; set; }
/// <summary>
/// The description of the return section.
/// If a structured comment is available, the description will be extracted from the return section.
/// Otherwise, the description will be null.
/// </summary>
public string? ReturnDescription { get; set; }
}
public class FunctionParameterContract
{
/// <summary>
/// The name of the parameter.
/// </summary>
public string? Name { get; set; }
/// <summary>
/// The description of the parameter.
/// This will be extracted from the param section of the structured comment if available.
/// Otherwise, the description will be null.
/// </summary>
public string? Description { get; set; }
/// <summary>
/// The type of the parameter.
/// </summary>
public Type? ParameterType { get; set; }
/// <summary>
/// If the parameter is a required parameter.
/// </summary>
public bool IsRequired { get; set; }
/// <summary>
/// The default value of the parameter.
/// </summary>
public object? DefaultValue { get; set; }
}

View File

@@ -0,0 +1,117 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Workflow.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
namespace AutoGen.Core;
/// <summary>
/// Obsolete: please use <see cref="Graph"/>
/// </summary>
[Obsolete("please use Graph")]
public class Workflow : Graph
{
[Obsolete("please use Graph")]
public Workflow(IEnumerable<Transition> transitions)
: base(transitions)
{
}
}
public class Graph
{
private readonly List<Transition> transitions = new List<Transition>();
public Graph(IEnumerable<Transition> transitions)
{
this.transitions.AddRange(transitions);
}
public void AddTransition(Transition transition)
{
transitions.Add(transition);
}
/// <summary>
/// Get the transitions of the workflow.
/// </summary>
public IEnumerable<Transition> Transitions => transitions;
/// <summary>
/// Get the next available agents that the messages can be transit to.
/// </summary>
/// <param name="fromAgent">the from agent</param>
/// <param name="messages">messages</param>
/// <returns>A list of agents that the messages can be transit to</returns>
public async Task<IEnumerable<IAgent>> TransitToNextAvailableAgentsAsync(IAgent fromAgent, IEnumerable<IMessage> messages)
{
var nextAgents = new List<IAgent>();
var availableTransitions = transitions.FindAll(t => t.From == fromAgent) ?? Enumerable.Empty<Transition>();
foreach (var transition in availableTransitions)
{
if (await transition.CanTransitionAsync(messages))
{
nextAgents.Add(transition.To);
}
}
return nextAgents;
}
}
/// <summary>
/// Represents a transition between two agents.
/// </summary>
public class Transition
{
private readonly IAgent _from;
private readonly IAgent _to;
private readonly Func<IAgent, IAgent, IEnumerable<IMessage>, Task<bool>>? _canTransition;
/// <summary>
/// Create a new instance of <see cref="Transition"/>.
/// This constructor is used for testing purpose only.
/// To create a new instance of <see cref="Transition"/>, use <see cref="Transition.Create{TFromAgent, TToAgent}(TFromAgent, TToAgent, Func{TFromAgent, TToAgent, IEnumerable{IMessage}, Task{bool}}?)"/>.
/// </summary>
/// <param name="from">from agent</param>
/// <param name="to">to agent</param>
/// <param name="canTransitionAsync">detect if the transition is allowed, default to be always true</param>
internal Transition(IAgent from, IAgent to, Func<IAgent, IAgent, IEnumerable<IMessage>, Task<bool>>? canTransitionAsync = null)
{
_from = from;
_to = to;
_canTransition = canTransitionAsync;
}
/// <summary>
/// Create a new instance of <see cref="Transition"/>.
/// </summary>
/// <returns><see cref="Transition"/></returns>"
public static Transition Create<TFromAgent, TToAgent>(TFromAgent from, TToAgent to, Func<TFromAgent, TToAgent, IEnumerable<IMessage>, Task<bool>>? canTransitionAsync = null)
where TFromAgent : IAgent
where TToAgent : IAgent
{
return new Transition(from, to, (fromAgent, toAgent, messages) => canTransitionAsync?.Invoke((TFromAgent)fromAgent, (TToAgent)toAgent, messages) ?? Task.FromResult(true));
}
public IAgent From => _from;
public IAgent To => _to;
/// <summary>
/// Check if the transition is allowed.
/// </summary>
/// <param name="messages">messages</param>
public Task<bool> CanTransitionAsync(IEnumerable<IMessage> messages)
{
if (_canTransition == null)
{
return Task.FromResult(true);
}
return _canTransition(this.From, this.To, messages);
}
}

View File

@@ -0,0 +1,183 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GroupChat.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
public class GroupChat : IGroupChat
{
private IAgent? admin;
private List<IAgent> agents = new List<IAgent>();
private IEnumerable<IMessage> initializeMessages = new List<IMessage>();
private Graph? workflow = null;
public IEnumerable<IMessage>? Messages { get; private set; }
/// <summary>
/// Create a group chat. The next speaker will be decided by a combination effort of the admin and the workflow.
/// </summary>
/// <param name="admin">admin agent. If provided, the admin will be invoked to decide the next speaker.</param>
/// <param name="workflow">workflow of the group chat. If provided, the next speaker will be decided by the workflow.</param>
/// <param name="members">group members.</param>
/// <param name="initializeMessages"></param>
public GroupChat(
IEnumerable<IAgent> members,
IAgent? admin = null,
IEnumerable<IMessage>? initializeMessages = null,
Graph? workflow = null)
{
this.admin = admin;
this.agents = members.ToList();
this.initializeMessages = initializeMessages ?? new List<IMessage>();
this.workflow = workflow;
this.Validation();
}
private void Validation()
{
// check if all agents has a name
if (this.agents.Any(x => string.IsNullOrEmpty(x.Name)))
{
throw new Exception("All agents must have a name.");
}
// check if any agents has the same name
var names = this.agents.Select(x => x.Name).ToList();
if (names.Distinct().Count() != names.Count)
{
throw new Exception("All agents must have a unique name.");
}
// if there's a workflow
// check if the agents in that workflow are in the group chat
if (this.workflow != null)
{
var agentNamesInWorkflow = this.workflow.Transitions.Select(x => x.From.Name!).Concat(this.workflow.Transitions.Select(x => x.To.Name!)).Distinct();
if (agentNamesInWorkflow.Any(x => !this.agents.Select(a => a.Name).Contains(x)))
{
throw new Exception("All agents in the workflow must be in the group chat.");
}
}
// must provide one of admin or workflow
if (this.admin == null && this.workflow == null)
{
throw new Exception("Must provide one of admin or workflow.");
}
}
/// <summary>
/// Select the next speaker based on the conversation history.
/// The next speaker will be decided by a combination effort of the admin and the workflow.
/// Firstly, a group of candidates will be selected by the workflow. If there's only one candidate, then that candidate will be the next speaker.
/// Otherwise, the admin will be invoked to decide the next speaker using role-play prompt.
/// </summary>
/// <param name="currentSpeaker">current speaker</param>
/// <param name="conversationHistory">conversation history</param>
/// <returns>next speaker.</returns>
public async Task<IAgent> SelectNextSpeakerAsync(IAgent currentSpeaker, IEnumerable<IMessage> conversationHistory)
{
var agentNames = this.agents.Select(x => x.Name).ToList();
if (this.workflow != null)
{
var nextAvailableAgents = await this.workflow.TransitToNextAvailableAgentsAsync(currentSpeaker, conversationHistory);
agentNames = nextAvailableAgents.Select(x => x.Name).ToList();
if (agentNames.Count() == 0)
{
throw new Exception("No next available agents found in the current workflow");
}
if (agentNames.Count() == 1)
{
return this.agents.FirstOrDefault(x => x.Name == agentNames.First());
}
}
if (this.admin == null)
{
throw new Exception("No admin is provided.");
}
var systemMessage = new TextMessage(Role.System,
content: $@"You are in a role play game. Carefully read the conversation history and carry on the conversation.
The available roles are:
{string.Join(",", agentNames)}
Each message will start with 'From name:', e.g:
From admin:
//your message//.");
var conv = this.ProcessConversationsForRolePlay(this.initializeMessages, conversationHistory);
var messages = new IMessage[] { systemMessage }.Concat(conv);
var response = await this.admin.GenerateReplyAsync(
messages: messages,
options: new GenerateReplyOptions
{
Temperature = 0,
MaxToken = 128,
StopSequence = [":"],
Functions = [],
});
var name = response?.GetContent() ?? throw new Exception("No name is returned.");
// remove From
name = name!.Substring(5);
return this.agents.First(x => x.Name!.ToLower() == name.ToLower());
}
/// <inheritdoc />
public void AddInitializeMessage(IMessage message)
{
this.SendIntroduction(message);
}
public async Task<IEnumerable<IMessage>> CallAsync(
IEnumerable<IMessage>? conversationWithName = null,
int maxRound = 10,
CancellationToken ct = default)
{
var conversationHistory = new List<IMessage>();
if (conversationWithName != null)
{
conversationHistory.AddRange(conversationWithName);
}
var lastSpeaker = conversationHistory.LastOrDefault()?.From switch
{
null => this.agents.First(),
_ => this.agents.FirstOrDefault(x => x.Name == conversationHistory.Last().From) ?? throw new Exception("The agent is not in the group chat"),
};
var round = 0;
while (round < maxRound)
{
var currentSpeaker = await this.SelectNextSpeakerAsync(lastSpeaker, conversationHistory);
var processedConversation = this.ProcessConversationForAgent(this.initializeMessages, conversationHistory);
var result = await currentSpeaker.GenerateReplyAsync(processedConversation) ?? throw new Exception("No result is returned.");
conversationHistory.Add(result);
// if message is terminate message, then terminate the conversation
if (result?.IsGroupChatTerminateMessage() ?? false)
{
break;
}
lastSpeaker = currentSpeaker;
round++;
}
return conversationHistory;
}
public void SendIntroduction(IMessage message)
{
this.initializeMessages = this.initializeMessages.Append(message);
}
}

View File

@@ -0,0 +1,100 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// RoundRobinGroupChat.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
/// <summary>
/// Obsolete: please use <see cref="RoundRobinGroupChat"/>
/// </summary>
[Obsolete("please use RoundRobinGroupChat")]
public class SequentialGroupChat : RoundRobinGroupChat
{
[Obsolete("please use RoundRobinGroupChat")]
public SequentialGroupChat(IEnumerable<IAgent> agents, List<IMessage>? initializeMessages = null)
: base(agents, initializeMessages)
{
}
}
/// <summary>
/// A group chat that allows agents to talk in a round-robin manner.
/// </summary>
public class RoundRobinGroupChat : IGroupChat
{
private readonly List<IAgent> agents = new List<IAgent>();
private readonly List<IMessage> initializeMessages = new List<IMessage>();
public RoundRobinGroupChat(
IEnumerable<IAgent> agents,
List<IMessage>? initializeMessages = null)
{
this.agents.AddRange(agents);
this.initializeMessages = initializeMessages ?? new List<IMessage>();
}
/// <inheritdoc />
public void AddInitializeMessage(IMessage message)
{
this.SendIntroduction(message);
}
public async Task<IEnumerable<IMessage>> CallAsync(
IEnumerable<IMessage>? conversationWithName = null,
int maxRound = 10,
CancellationToken ct = default)
{
var conversationHistory = new List<IMessage>();
if (conversationWithName != null)
{
conversationHistory.AddRange(conversationWithName);
}
var lastSpeaker = conversationHistory.LastOrDefault()?.From switch
{
null => this.agents.First(),
_ => this.agents.FirstOrDefault(x => x.Name == conversationHistory.Last().From) ?? throw new Exception("The agent is not in the group chat"),
};
var round = 0;
while (round < maxRound)
{
var currentSpeaker = this.SelectNextSpeaker(lastSpeaker);
var processedConversation = this.ProcessConversationForAgent(this.initializeMessages, conversationHistory);
var result = await currentSpeaker.GenerateReplyAsync(processedConversation) ?? throw new Exception("No result is returned.");
conversationHistory.Add(result);
// if message is terminate message, then terminate the conversation
if (result?.IsGroupChatTerminateMessage() ?? false)
{
break;
}
lastSpeaker = currentSpeaker;
round++;
}
return conversationHistory;
}
public void SendIntroduction(IMessage message)
{
this.initializeMessages.Add(message);
}
private IAgent SelectNextSpeaker(IAgent currentSpeaker)
{
var index = this.agents.IndexOf(currentSpeaker);
if (index == -1)
{
throw new ArgumentException("The agent is not in the group chat", nameof(currentSpeaker));
}
var nextIndex = (index + 1) % this.agents.Count;
return this.agents[nextIndex];
}
}

View File

@@ -0,0 +1,22 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IGroupChat.cs
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
public interface IGroupChat
{
/// <summary>
/// Send an introduction message to the group chat.
/// </summary>
void SendIntroduction(IMessage message);
[Obsolete("please use SendIntroduction")]
void AddInitializeMessage(IMessage message);
Task<IEnumerable<IMessage>> CallAsync(IEnumerable<IMessage>? conversation = null, int maxRound = 10, CancellationToken ct = default);
}

View File

@@ -0,0 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ILLMConfig.cs
namespace AutoGen.Core;
public interface ILLMConfig
{
}

View File

@@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AggregateMessage.cs
using System;
using System.Collections.Generic;
namespace AutoGen.Core;
public class AggregateMessage<TMessage1, TMessage2> : IMessage
where TMessage1 : IMessage
where TMessage2 : IMessage
{
public AggregateMessage(TMessage1 message1, TMessage2 message2, string? from = null)
{
this.From = from;
this.Message1 = message1;
this.Message2 = message2;
this.Validate();
}
public TMessage1 Message1 { get; }
public TMessage2 Message2 { get; }
public string? From { get; set; }
private void Validate()
{
var messages = new List<IMessage> { this.Message1, this.Message2 };
// the from property of all messages should be the same with the from property of the aggregate message
foreach (var message in messages)
{
if (message.From != this.From)
{
throw new ArgumentException($"The from property of the message {message} is different from the from property of the aggregate message {this}");
}
}
}
public override string ToString()
{
var stringBuilder = new System.Text.StringBuilder();
var messages = new List<IMessage> { this.Message1, this.Message2 };
stringBuilder.Append($"AggregateMessage({this.From})");
foreach (var message in messages)
{
stringBuilder.Append($"\n\t{message}");
}
return stringBuilder.ToString();
}
}

View File

@@ -0,0 +1,52 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IMessage.cs
namespace AutoGen.Core;
/// <summary>
/// The universal message interface for all message types in AutoGen.
/// <para>Related PR: https://github.com/microsoft/autogen/pull/1676</para>
/// <para>Built-in message types</para>
/// <list type="bullet">
/// <item>
/// <see cref="TextMessage"/>: plain text message.
/// </item>
/// <item>
/// <see cref="ImageMessage"/>: image message.
/// </item>
/// <item>
/// <see cref="MultiModalMessage"/>: message type for multimodal message. The current support message items are <see cref="TextMessage"/> and <see cref="ImageMessage"/>.
/// </item>
/// <item>
/// <see cref="ToolCallMessage"/>: message type for tool call. This message supports both single and parallel tool call.
/// </item>
/// <item>
/// <see cref="ToolCallResultMessage"/>: message type for tool call result.
/// </item>
/// <item>
/// <see cref="Message"/>: This type is used by previous version of AutoGen. And it's reserved for backward compatibility.
/// </item>
/// <item>
/// <see cref="AggregateMessage{TMessage1, TMessage2}"/>: an aggregate message type that contains two message types.
/// This type is useful when you want to combine two message types into one unique message type. One example is when invoking a tool call and you want to return both <see cref="ToolCallMessage"/> and <see cref="ToolCallResultMessage"/>.
/// One example of how this type is used in AutoGen is <see cref="FunctionCallMiddleware"/>
/// </item>
/// </list>
/// </summary>
public interface IMessage : IStreamingMessage
{
}
public interface IMessage<out T> : IMessage, IStreamingMessage<T>
{
}
public interface IStreamingMessage
{
string? From { get; set; }
}
public interface IStreamingMessage<out T> : IStreamingMessage
{
T Content { get; }
}

View File

@@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ImageMessage.cs
using System;
namespace AutoGen.Core;
public class ImageMessage : IMessage
{
public ImageMessage(Role role, string url, string? from = null)
{
this.Role = role;
this.From = from;
this.Url = url;
}
public ImageMessage(Role role, Uri uri, string? from = null)
{
this.Role = role;
this.From = from;
this.Url = uri.ToString();
}
public Role Role { get; set; }
public string Url { get; set; }
public string? From { get; set; }
public override string ToString()
{
return $"ImageMessage({this.Role}, {this.Url}, {this.From})";
}
}

View File

@@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Message.cs
using System.Collections.Generic;
namespace AutoGen.Core;
public class Message : IMessage
{
public Message(
Role role,
string? content,
string? from = null,
ToolCall? toolCall = null)
{
this.Role = role;
this.Content = content;
this.From = from;
this.FunctionName = toolCall?.FunctionName;
this.FunctionArguments = toolCall?.FunctionArguments;
}
public Message(Message other)
: this(other.Role, other.Content, other.From)
{
this.FunctionName = other.FunctionName;
this.FunctionArguments = other.FunctionArguments;
this.Value = other.Value;
this.Metadata = other.Metadata;
}
public Role Role { get; set; }
public string? Content { get; set; }
public string? From { get; set; }
public string? FunctionName { get; set; }
public string? FunctionArguments { get; set; }
/// <summary>
/// raw message
/// </summary>
public object? Value { get; set; }
public IList<KeyValuePair<string, object>> Metadata { get; set; } = new List<KeyValuePair<string, object>>();
public override string ToString()
{
return $"Message({this.Role}, {this.Content}, {this.From}, {this.FunctionName}, {this.FunctionArguments})";
}
}

View File

@@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// MessageEnvelope.cs
using System.Collections.Generic;
namespace AutoGen.Core;
public abstract class MessageEnvelope : IMessage, IStreamingMessage
{
public MessageEnvelope(string? from = null, IDictionary<string, object>? metadata = null)
{
this.From = from;
this.Metadata = metadata ?? new Dictionary<string, object>();
}
public static MessageEnvelope<TContent> Create<TContent>(TContent content, string? from = null, IDictionary<string, object>? metadata = null)
{
return new MessageEnvelope<TContent>(content, from, metadata);
}
public string? From { get; set; }
public IDictionary<string, object> Metadata { get; set; }
}
public class MessageEnvelope<T> : MessageEnvelope, IMessage<T>, IStreamingMessage<T>
{
public MessageEnvelope(T content, string? from = null, IDictionary<string, object>? metadata = null)
: base(from, metadata)
{
this.Content = content;
this.From = from;
this.Metadata = metadata ?? new Dictionary<string, object>();
}
public T Content { get; }
}

View File

@@ -0,0 +1,58 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// MultiModalMessage.cs
using System;
using System.Collections.Generic;
namespace AutoGen.Core;
public class MultiModalMessage : IMessage
{
public MultiModalMessage(Role role, IEnumerable<IMessage> content, string? from = null)
{
this.Role = role;
this.Content = content;
this.From = from;
this.Validate();
}
public Role Role { get; set; }
public IEnumerable<IMessage> Content { get; set; }
public string? From { get; set; }
private void Validate()
{
foreach (var message in this.Content)
{
if (message.From != this.From)
{
var reason = $"The from property of the message {message} is different from the from property of the aggregate message {this}";
throw new ArgumentException($"Invalid aggregate message {reason}");
}
}
// all message must be either text or image
foreach (var message in this.Content)
{
if (message is not TextMessage && message is not ImageMessage)
{
var reason = $"The message {message} is not a text or image message";
throw new ArgumentException($"Invalid aggregate message {reason}");
}
}
}
public override string ToString()
{
var stringBuilder = new System.Text.StringBuilder();
stringBuilder.Append($"MultiModalMessage({this.Role}, {this.From})");
foreach (var message in this.Content)
{
stringBuilder.Append($"\n\t{message}");
}
return stringBuilder.ToString();
}
}

View File

@@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Role.cs
using System;
namespace AutoGen.Core;
public readonly struct Role : IEquatable<Role>
{
private readonly string label;
internal Role(string name)
{
label = name;
}
public static Role User { get; } = new Role("user");
public static Role Assistant { get; } = new Role("assistant");
public static Role System { get; } = new Role("system");
public static Role Function { get; } = new Role("function");
public bool Equals(Role other)
{
return label.Equals(other.label, StringComparison.OrdinalIgnoreCase);
}
public override string ToString()
{
return label;
}
public override bool Equals(object? obj)
{
return obj is Role other && Equals(other);
}
public override int GetHashCode()
{
return label.GetHashCode();
}
public static bool operator ==(Role left, Role right)
{
return left.Equals(right);
}
public static bool operator !=(Role left, Role right)
{
return !(left == right);
}
}

View File

@@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// TextMessage.cs
namespace AutoGen.Core;
public class TextMessage : IMessage, IStreamingMessage
{
public TextMessage(Role role, string content, string? from = null)
{
this.Content = content;
this.Role = role;
this.From = from;
}
public TextMessage(TextMessageUpdate update)
{
this.Content = update.Content ?? string.Empty;
this.Role = update.Role;
this.From = update.From;
}
public void Update(TextMessageUpdate update)
{
if (update.Role != this.Role)
{
throw new System.ArgumentException("Role mismatch", nameof(update));
}
if (update.From != this.From)
{
throw new System.ArgumentException("From mismatch", nameof(update));
}
this.Content = this.Content + update.Content ?? string.Empty;
}
public Role Role { get; set; }
public string Content { get; set; }
public string? From { get; set; }
public override string ToString()
{
return $"TextMessage({this.Role}, {this.Content}, {this.From})";
}
}
public class TextMessageUpdate : IStreamingMessage
{
public TextMessageUpdate(Role role, string? content, string? from = null)
{
this.Content = content;
this.From = from;
this.Role = role;
}
public string? Content { get; set; }
public string? From { get; set; }
public Role Role { get; set; }
}

View File

@@ -0,0 +1,108 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ToolCallMessage.cs
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace AutoGen.Core;
public class ToolCall
{
public ToolCall(string functionName, string functionArgs)
{
this.FunctionName = functionName;
this.FunctionArguments = functionArgs;
}
public ToolCall(string functionName, string functionArgs, string result)
{
this.FunctionName = functionName;
this.FunctionArguments = functionArgs;
this.Result = result;
}
public string FunctionName { get; set; }
public string FunctionArguments { get; set; }
public string? Result { get; set; }
public override string ToString()
{
return $"ToolCall({this.FunctionName}, {this.FunctionArguments}, {this.Result})";
}
}
public class ToolCallMessage : IMessage
{
public ToolCallMessage(IEnumerable<ToolCall> toolCalls, string? from = null)
{
this.From = from;
this.ToolCalls = toolCalls.ToList();
}
public ToolCallMessage(string functionName, string functionArgs, string? from = null)
{
this.From = from;
this.ToolCalls = new List<ToolCall> { new ToolCall(functionName, functionArgs) };
}
public ToolCallMessage(ToolCallMessageUpdate update)
{
this.From = update.From;
this.ToolCalls = new List<ToolCall> { new ToolCall(update.FunctionName, update.FunctionArgumentUpdate) };
}
public void Update(ToolCallMessageUpdate update)
{
// firstly, valid if the update is from the same agent
if (update.From != this.From)
{
throw new System.ArgumentException("From mismatch", nameof(update));
}
// if update.FunctionName exists in the tool calls, update the function arguments
var toolCall = this.ToolCalls.FirstOrDefault(tc => tc.FunctionName == update.FunctionName);
if (toolCall is not null)
{
toolCall.FunctionArguments += update.FunctionArgumentUpdate;
}
else
{
this.ToolCalls.Add(new ToolCall(update.FunctionName, update.FunctionArgumentUpdate));
}
}
public IList<ToolCall> ToolCalls { get; set; }
public string? From { get; set; }
public override string ToString()
{
var sb = new StringBuilder();
sb.Append($"ToolCallMessage({this.From})");
foreach (var toolCall in this.ToolCalls)
{
sb.Append($"\n\t{toolCall}");
}
return sb.ToString();
}
}
public class ToolCallMessageUpdate : IStreamingMessage
{
public ToolCallMessageUpdate(string functionName, string functionArgumentUpdate, string? from = null)
{
this.From = from;
this.FunctionName = functionName;
this.FunctionArgumentUpdate = functionArgumentUpdate;
}
public string? From { get; set; }
public string FunctionName { get; set; }
public string FunctionArgumentUpdate { get; set; }
}

View File

@@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ToolCallResultMessage.cs
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace AutoGen.Core;
public class ToolCallResultMessage : IMessage
{
public ToolCallResultMessage(IEnumerable<ToolCall> toolCalls, string? from = null)
{
this.From = from;
this.ToolCalls = toolCalls.ToList();
}
public ToolCallResultMessage(string result, string functionName, string functionArgs, string? from = null)
{
this.From = from;
var toolCall = new ToolCall(functionName, functionArgs);
toolCall.Result = result;
this.ToolCalls = [toolCall];
}
/// <summary>
/// The original tool call message
/// </summary>
public IList<ToolCall> ToolCalls { get; set; }
public string? From { get; set; }
public override string ToString()
{
var sb = new StringBuilder();
sb.Append($"ToolCallResultMessage({this.From})");
foreach (var toolCall in this.ToolCalls)
{
sb.Append($"\n\t{toolCall}");
}
return sb.ToString();
}
private void Validate()
{
// each tool call must have a result
foreach (var toolCall in this.ToolCalls)
{
if (string.IsNullOrEmpty(toolCall.Result))
{
throw new System.ArgumentException($"The tool call {toolCall} does not have a result");
}
}
}
}

View File

@@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// DelegateMiddleware.cs
using System;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
internal class DelegateMiddleware : IMiddleware
{
/// <summary>
/// middleware delegate. Call into the next function to continue the execution of the next middleware. Otherwise, short cut the middleware execution.
/// </summary>
/// <param name="cancellationToken">cancellation token</param>
public delegate Task<IMessage> MiddlewareDelegate(
MiddlewareContext context,
IAgent agent,
CancellationToken cancellationToken);
private readonly MiddlewareDelegate middlewareDelegate;
public DelegateMiddleware(string? name, Func<MiddlewareContext, IAgent, CancellationToken, Task<IMessage>> middlewareDelegate)
{
this.Name = name;
this.middlewareDelegate = async (context, agent, cancellationToken) =>
{
return await middlewareDelegate(context, agent, cancellationToken);
};
}
public string? Name { get; }
public Task<IMessage> InvokeAsync(
MiddlewareContext context,
IAgent agent,
CancellationToken cancellationToken = default)
{
var messages = context.Messages;
var options = context.Options;
return this.middlewareDelegate(context, agent, cancellationToken);
}
}

View File

@@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// DelegateStreamingMiddleware.cs
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
internal class DelegateStreamingMiddleware : IStreamingMiddleware
{
public delegate Task<IAsyncEnumerable<IStreamingMessage>> MiddlewareDelegate(
MiddlewareContext context,
IStreamingAgent agent,
CancellationToken cancellationToken);
private readonly MiddlewareDelegate middlewareDelegate;
public DelegateStreamingMiddleware(string? name, MiddlewareDelegate middlewareDelegate)
{
this.Name = name;
this.middlewareDelegate = middlewareDelegate;
}
public string? Name { get; }
public Task<IAsyncEnumerable<IStreamingMessage>> InvokeAsync(
MiddlewareContext context,
IStreamingAgent agent,
CancellationToken cancellationToken = default)
{
var messages = context.Messages;
var options = context.Options;
return this.middlewareDelegate(context, agent, cancellationToken);
}
}

View File

@@ -0,0 +1,178 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// FunctionCallMiddleware.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
/// <summary>
/// The middleware that process function call message that both send to an agent or reply from an agent.
/// <para>If the last message is <see cref="ToolCallMessage"/> and the tool calls is available in this middleware's function map,
/// the tools from the last message will be invoked and a <see cref="ToolCallResultMessage"/> will be returned. In this situation,
/// the inner agent will be short-cut and won't be invoked.</para>
/// <para>Otherwise, the message will be sent to the inner agent. In this situation</para>
/// <para>if the reply from the inner agent is <see cref="ToolCallMessage"/>,
/// and the tool calls is available in this middleware's function map, the tools from the reply will be invoked,
/// and a <see cref="AggregateMessage{TMessage1, TMessage2}"/> where TMessage1 is <see cref="ToolCallMessage"/> and TMessage2 is <see cref="ToolCallResultMessage"/>"/>
/// will be returned.
/// </para>
/// <para>If the reply from the inner agent is <see cref="ToolCallMessage"/> but the tool calls is not available in this middleware's function map,
/// or the reply from the inner agent is not <see cref="ToolCallMessage"/>, the original reply from the inner agent will be returned.</para>
/// <para>
/// When used as a streaming middleware, if the streaming reply from the inner agent is <see cref="ToolCallMessageUpdate"/> or <see cref="TextMessageUpdate"/>,
/// This middleware will update the message accordingly and invoke the function if the tool call is available in this middleware's function map.
/// If the streaming reply from the inner agent is other types of message, the most recent message will be used to invoke the function.
/// </para>
/// </summary>
public class FunctionCallMiddleware : IMiddleware, IStreamingMiddleware
{
private readonly IEnumerable<FunctionContract>? functions;
private readonly IDictionary<string, Func<string, Task<string>>>? functionMap;
public FunctionCallMiddleware(
IEnumerable<FunctionContract>? functions = null,
IDictionary<string, Func<string, Task<string>>>? functionMap = null,
string? name = null)
{
this.Name = name ?? nameof(FunctionCallMiddleware);
this.functions = functions;
this.functionMap = functionMap;
}
public string? Name { get; }
public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default)
{
var lastMessage = context.Messages.Last();
if (lastMessage is ToolCallMessage toolCallMessage)
{
return await this.InvokeToolCallMessagesBeforeInvokingAgentAsync(toolCallMessage, agent);
}
// combine functions
var options = new GenerateReplyOptions(context.Options ?? new GenerateReplyOptions());
var combinedFunctions = this.functions?.Concat(options.Functions ?? []) ?? options.Functions;
options.Functions = combinedFunctions?.ToArray();
var reply = await agent.GenerateReplyAsync(context.Messages, options, cancellationToken);
// if the reply is a function call message plus the function's name is available in function map, invoke the function and return the result instead of sending to the agent.
if (reply is ToolCallMessage toolCallMsg)
{
return await this.InvokeToolCallMessagesAfterInvokingAgentAsync(toolCallMsg, agent);
}
// for all other messages, just return the reply from the agent.
return reply;
}
public Task<IAsyncEnumerable<IStreamingMessage>> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default)
{
return Task.FromResult(this.StreamingInvokeAsync(context, agent, cancellationToken));
}
private async IAsyncEnumerable<IStreamingMessage> StreamingInvokeAsync(
MiddlewareContext context,
IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken)
{
var lastMessage = context.Messages.Last();
if (lastMessage is ToolCallMessage toolCallMessage)
{
yield return await this.InvokeToolCallMessagesBeforeInvokingAgentAsync(toolCallMessage, agent);
}
// combine functions
var options = new GenerateReplyOptions(context.Options ?? new GenerateReplyOptions());
var combinedFunctions = this.functions?.Concat(options.Functions ?? []) ?? options.Functions;
options.Functions = combinedFunctions?.ToArray();
IStreamingMessage? initMessage = default;
await foreach (var message in await agent.GenerateStreamingReplyAsync(context.Messages, options, cancellationToken))
{
if (message is ToolCallMessageUpdate toolCallMessageUpdate && this.functionMap != null)
{
if (initMessage is null)
{
initMessage = new ToolCallMessage(toolCallMessageUpdate);
}
else if (initMessage is ToolCallMessage toolCall)
{
toolCall.Update(toolCallMessageUpdate);
}
else
{
throw new InvalidOperationException("The first message is ToolCallMessage, but the update message is not ToolCallMessageUpdate");
}
}
else
{
yield return message;
}
}
if (initMessage is ToolCallMessage toolCallMsg)
{
yield return await this.InvokeToolCallMessagesAfterInvokingAgentAsync(toolCallMsg, agent);
}
}
private async Task<ToolCallResultMessage> InvokeToolCallMessagesBeforeInvokingAgentAsync(ToolCallMessage toolCallMessage, IAgent agent)
{
var toolCallResult = new List<ToolCall>();
var toolCalls = toolCallMessage.ToolCalls;
foreach (var toolCall in toolCalls)
{
var functionName = toolCall.FunctionName;
var functionArguments = toolCall.FunctionArguments;
if (this.functionMap?.TryGetValue(functionName, out var func) is true)
{
var result = await func(functionArguments);
toolCallResult.Add(new ToolCall(functionName, functionArguments, result));
}
else if (this.functionMap is not null)
{
var errorMessage = $"Function {functionName} is not available. Available functions are: {string.Join(", ", this.functionMap.Select(f => f.Key))}";
toolCallResult.Add(new ToolCall(functionName, functionArguments, errorMessage));
}
else
{
throw new InvalidOperationException("FunctionMap is not available");
}
}
return new ToolCallResultMessage(toolCallResult, from: agent.Name);
}
private async Task<IMessage> InvokeToolCallMessagesAfterInvokingAgentAsync(ToolCallMessage toolCallMsg, IAgent agent)
{
var toolCallsReply = toolCallMsg.ToolCalls;
var toolCallResult = new List<ToolCall>();
foreach (var toolCall in toolCallsReply)
{
var fName = toolCall.FunctionName;
var fArgs = toolCall.FunctionArguments;
if (this.functionMap?.TryGetValue(fName, out var func) is true)
{
var result = await func(fArgs);
toolCallResult.Add(new ToolCall(fName, fArgs, result));
}
}
if (toolCallResult.Count() > 0)
{
var toolCallResultMessage = new ToolCallResultMessage(toolCallResult, from: agent.Name);
return new AggregateMessage<ToolCallMessage, ToolCallResultMessage>(toolCallMsg, toolCallResultMessage, from: agent.Name);
}
else
{
return toolCallMsg;
}
}
}

View File

@@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IMiddleware.cs
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
/// <summary>
/// The middleware interface
/// </summary>
public interface IMiddleware
{
/// <summary>
/// the name of the middleware
/// </summary>
public string? Name { get; }
/// <summary>
/// The method to invoke the middleware
/// </summary>
public Task<IMessage> InvokeAsync(
MiddlewareContext context,
IAgent agent,
CancellationToken cancellationToken = default);
}

View File

@@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// IStreamingMiddleware.cs
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
/// <summary>
/// The streaming middleware interface
/// </summary>
public interface IStreamingMiddleware
{
public string? Name { get; }
public Task<IAsyncEnumerable<IStreamingMessage>> InvokeAsync(
MiddlewareContext context,
IStreamingAgent agent,
CancellationToken cancellationToken = default);
}

View File

@@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// MiddlewareContext.cs
using System.Collections.Generic;
namespace AutoGen.Core;
public class MiddlewareContext
{
public MiddlewareContext(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options)
{
this.Messages = messages;
this.Options = options;
}
/// <summary>
/// Messages to send to the agent
/// </summary>
public IEnumerable<IMessage> Messages { get; }
/// <summary>
/// Options to generate the reply
/// </summary>
public GenerateReplyOptions? Options { get; }
}

View File

@@ -0,0 +1,87 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// PrintMessageMiddleware.cs
using System;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen.Core;
/// <summary>
/// The middleware that prints the reply from agent to the console.
/// </summary>
public class PrintMessageMiddleware : IMiddleware
{
public string? Name => nameof(PrintMessageMiddleware);
public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default)
{
if (agent is IStreamingAgent streamingAgent)
{
IMessage? recentUpdate = null;
await foreach (var message in await streamingAgent.GenerateStreamingReplyAsync(context.Messages, context.Options, cancellationToken))
{
if (message is TextMessageUpdate textMessageUpdate)
{
if (recentUpdate is null)
{
// Print from: xxx
Console.WriteLine($"from: {textMessageUpdate.From}");
recentUpdate = new TextMessage(textMessageUpdate);
Console.Write(textMessageUpdate.Content);
}
else if (recentUpdate is TextMessage recentTextMessage)
{
// Print the content of the message
Console.Write(textMessageUpdate.Content);
recentTextMessage.Update(textMessageUpdate);
}
else
{
throw new InvalidOperationException("The recent update is not a TextMessage");
}
}
else if (message is ToolCallMessageUpdate toolCallUpdate)
{
if (recentUpdate is null)
{
recentUpdate = new ToolCallMessage(toolCallUpdate);
}
else if (recentUpdate is ToolCallMessage recentToolCallMessage)
{
recentToolCallMessage.Update(toolCallUpdate);
}
else
{
throw new InvalidOperationException("The recent update is not a ToolCallMessage");
}
}
else if (message is IMessage imessage)
{
recentUpdate = imessage;
}
else
{
throw new InvalidOperationException("The message is not a valid message");
}
}
Console.WriteLine();
if (recentUpdate is not null && recentUpdate is not TextMessage)
{
Console.WriteLine(recentUpdate.FormatMessage());
}
return recentUpdate ?? throw new InvalidOperationException("The message is not a valid message");
}
else
{
var reply = await agent.GenerateReplyAsync(context.Messages, context.Options, cancellationToken);
var formattedMessages = reply.FormatMessage();
Console.WriteLine(formattedMessages);
return reply;
}
}
}

View File

@@ -0,0 +1,40 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<RootNamespace>AutoGen.DotnetInteractive</RootNamespace>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
</PropertyGroup>
<Import Project="$(RepoRoot)/dotnet/nuget/nuget-package.props" />
<PropertyGroup>
<!-- NuGet Package Settings -->
<Title>AutoGen.DotnetInteractive</Title>
<Description>
Dotnet interactive integration for AutoGen agents
</Description>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.DotNet.Interactive.VisualStudio" Version="$(MicrosoftDotnetInteractive)" />
</ItemGroup>
<ItemGroup>
<EmbeddedResource Include="dotnet-tools.json" />
<EmbeddedResource Include="RestoreInteractive.config" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Azure.AI.OpenAI" Version="$(AzureOpenAIVersion)" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\AutoGen\AutoGen.csproj" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,278 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// DotnetInteractiveFunction.cs
using System.Text;
using System.Text.Json;
using Azure.AI.OpenAI;
using Microsoft.DotNet.Interactive.Documents;
using Microsoft.DotNet.Interactive.Documents.Jupyter;
namespace AutoGen.DotnetInteractive;
public class DotnetInteractiveFunction : IDisposable
{
private readonly InteractiveService? _interactiveService = null;
private string? _notebookPath;
private readonly KernelInfoCollection _kernelInfoCollection = new KernelInfoCollection();
public DotnetInteractiveFunction(InteractiveService interactiveService, string? notebookPath = null, bool continueFromExistingNotebook = false)
{
this._interactiveService = interactiveService;
this._notebookPath = notebookPath;
this._kernelInfoCollection.Add(new KernelInfo("csharp"));
this._kernelInfoCollection.Add(new KernelInfo("markdown"));
if (this._notebookPath != null)
{
if (continueFromExistingNotebook == false)
{
// remove existing notebook
if (File.Exists(this._notebookPath))
{
File.Delete(this._notebookPath);
}
var document = new InteractiveDocument();
using var stream = File.OpenWrite(_notebookPath);
Notebook.Write(document, stream, this._kernelInfoCollection);
stream.Flush();
stream.Dispose();
}
else if (continueFromExistingNotebook == true && File.Exists(this._notebookPath))
{
// load existing notebook
using var readStream = File.OpenRead(this._notebookPath);
var document = Notebook.Read(readStream, this._kernelInfoCollection);
foreach (var cell in document.Elements)
{
if (cell.KernelName == "csharp")
{
var code = cell.Contents;
this._interactiveService.SubmitCSharpCodeAsync(code, default).Wait();
}
}
}
else
{
// create an empty notebook
var document = new InteractiveDocument();
using var stream = File.OpenWrite(_notebookPath);
Notebook.Write(document, stream, this._kernelInfoCollection);
stream.Flush();
stream.Dispose();
}
}
}
/// <summary>
/// Run existing dotnet code from message. Don't modify the code, run it as is.
/// </summary>
/// <param name="code">code.</param>
public async Task<string> RunCode(string code)
{
if (this._interactiveService == null)
{
throw new Exception("InteractiveService is not initialized.");
}
var result = await this._interactiveService.SubmitCSharpCodeAsync(code, default);
if (result != null)
{
// if result contains Error, return entire message
if (result.StartsWith("Error:"))
{
return result;
}
// add cell if _notebookPath is not null
if (this._notebookPath != null)
{
await AddCellAsync(code, "csharp");
}
// if result is over 100 characters, only return the first 100 characters.
if (result.Length > 100)
{
result = result.Substring(0, 100) + " (...too long to present)";
return result;
}
return result;
}
// add cell if _notebookPath is not null
if (this._notebookPath != null)
{
await AddCellAsync(code, "csharp");
}
return "Code run successfully. no output is available.";
}
/// <summary>
/// Install nuget packages.
/// </summary>
/// <param name="nugetPackages">nuget package to install.</param>
public async Task<string> InstallNugetPackages(string[] nugetPackages)
{
if (this._interactiveService == null)
{
throw new Exception("InteractiveService is not initialized.");
}
var codeSB = new StringBuilder();
foreach (var nuget in nugetPackages ?? Array.Empty<string>())
{
var nugetInstallCommand = $"#r \"nuget:{nuget}\"";
codeSB.AppendLine(nugetInstallCommand);
await this._interactiveService.SubmitCSharpCodeAsync(nugetInstallCommand, default);
}
var code = codeSB.ToString();
if (this._notebookPath != null)
{
await AddCellAsync(code, "csharp");
}
var sb = new StringBuilder();
sb.AppendLine("Installed nuget packages:");
foreach (var nuget in nugetPackages ?? Array.Empty<string>())
{
sb.AppendLine($"- {nuget}");
}
return sb.ToString();
}
private async Task AddCellAsync(string cellContent, string kernelName)
{
if (!File.Exists(this._notebookPath))
{
using var stream = File.OpenWrite(this._notebookPath);
Notebook.Write(new InteractiveDocument(), stream, this._kernelInfoCollection);
stream.Dispose();
}
using var readStream = File.OpenRead(this._notebookPath);
var document = Notebook.Read(readStream, this._kernelInfoCollection);
readStream.Dispose();
var cell = new InteractiveDocumentElement(cellContent, kernelName);
document.Add(cell);
using var writeStream = File.OpenWrite(this._notebookPath);
Notebook.Write(document, writeStream, this._kernelInfoCollection);
// sleep 3 seconds
await Task.Delay(3000);
writeStream.Flush();
writeStream.Dispose();
}
private class RunCodeSchema
{
public string code { get; set; } = string.Empty;
}
public Task<string> RunCodeWrapper(string arguments)
{
var schema = JsonSerializer.Deserialize<RunCodeSchema>(
arguments,
new JsonSerializerOptions
{
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
});
return RunCode(schema!.code);
}
public FunctionDefinition RunCodeFunction
{
get => new FunctionDefinition
{
Name = @"RunCode",
Description = """
Run existing dotnet code from message. Don't modify the code, run it as is.
""",
Parameters = BinaryData.FromObjectAsJson(new
{
Type = "object",
Properties = new
{
code = new
{
Type = @"string",
Description = @"code.",
},
},
Required = new[]
{
"code",
},
},
new JsonSerializerOptions
{
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
})
};
}
private class InstallNugetPackagesSchema
{
public string[] nugetPackages { get; set; } = Array.Empty<string>();
}
public Task<string> InstallNugetPackagesWrapper(string arguments)
{
var schema = JsonSerializer.Deserialize<InstallNugetPackagesSchema>(
arguments,
new JsonSerializerOptions
{
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
});
return InstallNugetPackages(schema!.nugetPackages);
}
public FunctionDefinition InstallNugetPackagesFunction
{
get => new FunctionDefinition
{
Name = @"InstallNugetPackages",
Description = """
Install nuget packages.
""",
Parameters = BinaryData.FromObjectAsJson(new
{
Type = "object",
Properties = new
{
nugetPackages = new
{
Type = @"array",
Items = new
{
Type = @"string",
},
Description = @"nuget package to install.",
},
},
Required = new[]
{
"nugetPackages",
},
},
new JsonSerializerOptions
{
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
})
};
}
public void Dispose()
{
this._interactiveService?.Dispose();
}
}

View File

@@ -0,0 +1,83 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AgentExtension.cs
using System.Text;
namespace AutoGen.DotnetInteractive;
public static class AgentExtension
{
/// <summary>
/// Register an AutoReply hook to run dotnet code block from message.
/// This hook will first detect if there's any dotnet code block (e.g. ```csharp and ```) in the most recent message.
/// if there's any, it will run the code block and send the result back as reply.
/// </summary>
/// <param name="agent">agent</param>
/// <param name="interactiveService">interactive service</param>
/// <param name="codeBlockPrefix">code block prefix</param>
/// <param name="codeBlockSuffix">code block suffix</param>
/// <param name="maximumOutputToKeep">maximum output to keep</param>
/// <example>
/// <![CDATA[
/// [!code-csharp[Example04_Dynamic_GroupChat_Coding_Task](~/../sample/AutoGen.BasicSamples/Example04_Dynamic_GroupChat_Coding_Task.cs)]
/// ]]>
/// </example>
public static IAgent RegisterDotnetCodeBlockExectionHook(
this IAgent agent,
InteractiveService interactiveService,
string codeBlockPrefix = "```csharp",
string codeBlockSuffix = "```",
int maximumOutputToKeep = 500)
{
return agent.RegisterReply(async (msgs, ct) =>
{
var lastMessage = msgs.LastOrDefault();
if (lastMessage == null || lastMessage.GetContent() is null)
{
return null;
}
// retrieve all code blocks from last message
var codeBlocks = lastMessage.GetContent()!.Split(new[] { codeBlockPrefix }, StringSplitOptions.RemoveEmptyEntries);
if (codeBlocks.Length <= 0)
{
return null;
}
// run code blocks
var result = new StringBuilder();
var i = 0;
result.AppendLine(@$"// [DOTNET_CODE_BLOCK_EXECUTION]");
foreach (var codeBlock in codeBlocks)
{
var codeBlockIndex = codeBlock.IndexOf(codeBlockSuffix);
if (codeBlockIndex == -1)
{
continue;
}
// remove code block suffix
var code = codeBlock.Substring(0, codeBlockIndex).Trim();
if (code.Length == 0)
{
continue;
}
var codeResult = await interactiveService.SubmitCSharpCodeAsync(code, ct);
if (codeResult != null)
{
result.AppendLine(@$"### Executing result for code block {i++}");
result.AppendLine(codeResult);
result.AppendLine("### End of executing result ###");
}
}
if (result.Length <= maximumOutputToKeep)
{
maximumOutputToKeep = result.Length;
}
return new TextMessage(Role.Assistant, result.ToString().Substring(0, maximumOutputToKeep), from: agent.Name);
});
}
}

View File

@@ -0,0 +1,4 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GlobalUsing.cs
global using AutoGen.Core;

View File

@@ -0,0 +1,261 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// InteractiveService.cs
using System.Diagnostics;
using System.Reactive.Linq;
using System.Reflection;
using Microsoft.DotNet.Interactive;
using Microsoft.DotNet.Interactive.App.Connection;
using Microsoft.DotNet.Interactive.Commands;
using Microsoft.DotNet.Interactive.Connection;
using Microsoft.DotNet.Interactive.Events;
using Microsoft.DotNet.Interactive.Utility;
namespace AutoGen.DotnetInteractive;
public class InteractiveService : IDisposable
{
private Kernel? kernel = null;
private Process? process = null;
private bool disposedValue;
private const string DotnetInteractiveToolNotInstallMessage = "Cannot find a tool in the manifest file that has a command named 'dotnet-interactive'.";
//private readonly ProcessJobTracker jobTracker = new ProcessJobTracker();
private string installingDirectory;
public event EventHandler<DisplayEvent>? DisplayEvent;
public event EventHandler<string>? Output;
public event EventHandler<CommandFailed>? CommandFailed;
public event EventHandler<HoverTextProduced>? HoverTextProduced;
/// <summary>
/// Create an instance of InteractiveService
/// </summary>
/// <param name="installingDirectory">dotnet interactive installing directory</param>
public InteractiveService(string installingDirectory)
{
this.installingDirectory = installingDirectory;
}
public async Task<bool> StartAsync(string workingDirectory, CancellationToken ct = default)
{
this.kernel = await this.CreateKernelAsync(workingDirectory, ct);
return true;
}
public async Task<string?> SubmitCommandAsync(KernelCommand cmd, CancellationToken ct)
{
if (this.kernel == null)
{
throw new Exception("Kernel is not running");
}
try
{
var res = await this.kernel.SendAndThrowOnCommandFailedAsync(cmd, ct);
var events = res.Events;
var displayValues = events.Where(x => x is StandardErrorValueProduced || x is StandardOutputValueProduced || x is ReturnValueProduced)
.SelectMany(x => (x as DisplayEvent)!.FormattedValues);
if (displayValues is null || displayValues.Count() == 0)
{
return null;
}
return string.Join("\n", displayValues.Select(x => x.Value));
}
catch (Exception ex)
{
return $"Error: {ex.Message}";
}
}
public async Task<string?> SubmitPowershellCodeAsync(string code, CancellationToken ct)
{
var command = new SubmitCode(code, targetKernelName: "pwsh");
return await this.SubmitCommandAsync(command, ct);
}
public async Task<string?> SubmitCSharpCodeAsync(string code, CancellationToken ct)
{
var command = new SubmitCode(code, targetKernelName: "csharp");
return await this.SubmitCommandAsync(command, ct);
}
private async Task<Kernel> CreateKernelAsync(string workingDirectory, CancellationToken ct = default)
{
try
{
var url = KernelHost.CreateHostUriForCurrentProcessId();
var compositeKernel = new CompositeKernel("cbcomposite");
var cmd = new string[]
{
"dotnet",
"tool",
"run",
"dotnet-interactive",
$"[cb-{Process.GetCurrentProcess().Id}]",
"stdio",
//"--default-kernel",
//"csharp",
"--working-dir",
$@"""{workingDirectory}""",
};
var connector = new StdIoKernelConnector(
cmd,
"root-proxy",
url,
new DirectoryInfo(workingDirectory));
// Start the dotnet-interactive tool and get a proxy for the root composite kernel therein.
using var rootProxyKernel = await connector.CreateRootProxyKernelAsync().ConfigureAwait(false);
// Get proxies for each subkernel present inside the dotnet-interactive tool.
var requestKernelInfoCommand = new RequestKernelInfo(rootProxyKernel.KernelInfo.RemoteUri);
var result =
await rootProxyKernel.SendAsync(
requestKernelInfoCommand,
ct).ConfigureAwait(false);
var subKernels = result.Events.OfType<KernelInfoProduced>();
foreach (var kernelInfoProduced in result.Events.OfType<KernelInfoProduced>())
{
var kernelInfo = kernelInfoProduced.KernelInfo;
if (kernelInfo is not null && !kernelInfo.IsProxy && !kernelInfo.IsComposite)
{
var proxyKernel = await connector.CreateProxyKernelAsync(kernelInfo).ConfigureAwait(false);
proxyKernel.SetUpValueSharingIfSupported();
compositeKernel.Add(proxyKernel);
}
}
//compositeKernel.DefaultKernelName = "csharp";
compositeKernel.Add(rootProxyKernel);
compositeKernel.KernelEvents.Subscribe(this.OnKernelDiagnosticEventReceived);
return compositeKernel;
}
catch (CommandLineInvocationException ex) when (ex.Message.Contains("Cannot find a tool in the manifest file that has a command named 'dotnet-interactive'"))
{
var success = this.RestoreDotnetInteractive();
if (success)
{
return await this.CreateKernelAsync(workingDirectory, ct);
}
throw;
}
}
private void OnKernelDiagnosticEventReceived(KernelEvent ke)
{
this.WriteLine("Receive data from kernel");
this.WriteLine(KernelEventEnvelope.Serialize(ke));
switch (ke)
{
case DisplayEvent de:
this.DisplayEvent?.Invoke(this, de);
break;
case CommandFailed cf:
this.CommandFailed?.Invoke(this, cf);
break;
case HoverTextProduced cf:
this.HoverTextProduced?.Invoke(this, cf);
break;
}
}
private void WriteLine(string data)
{
this.Output?.Invoke(this, data);
}
private bool RestoreDotnetInteractive()
{
this.WriteLine("Restore dotnet interactive tool");
// write RestoreInteractive.config from embedded resource to this.workingDirectory
var assembly = Assembly.GetAssembly(typeof(InteractiveService))!;
var resourceName = "AutoGen.DotnetInteractive.RestoreInteractive.config";
using (var stream = assembly.GetManifestResourceStream(resourceName)!)
using (var fileStream = File.Create(Path.Combine(this.installingDirectory, "RestoreInteractive.config")))
{
stream.CopyTo(fileStream);
}
// write dotnet-tool.json from embedded resource to this.workingDirectory
resourceName = "AutoGen.DotnetInteractive.dotnet-tools.json";
using (var stream2 = assembly.GetManifestResourceStream(resourceName)!)
using (var fileStream2 = File.Create(Path.Combine(this.installingDirectory, "dotnet-tools.json")))
{
stream2.CopyTo(fileStream2);
}
var psi = new ProcessStartInfo
{
FileName = "dotnet",
Arguments = $"tool restore --configfile RestoreInteractive.config",
WorkingDirectory = this.installingDirectory,
RedirectStandardInput = true,
RedirectStandardOutput = true,
RedirectStandardError = true,
UseShellExecute = false,
CreateNoWindow = true,
};
using var process = new Process { StartInfo = psi };
process.OutputDataReceived += this.PrintProcessOutput;
process.ErrorDataReceived += this.PrintProcessOutput;
process.Start();
process.BeginErrorReadLine();
process.BeginOutputReadLine();
process.WaitForExit();
return process.ExitCode == 0;
}
private void PrintProcessOutput(object sender, DataReceivedEventArgs e)
{
if (!string.IsNullOrEmpty(e.Data))
{
this.WriteLine(e.Data);
}
}
public bool IsRunning()
{
return this.kernel != null;
}
protected virtual void Dispose(bool disposing)
{
if (!disposedValue)
{
if (disposing)
{
this.kernel?.Dispose();
if (this.process != null)
{
this.process.Kill();
this.process.Dispose();
}
}
disposedValue = true;
}
}
public void Dispose()
{
// Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
Dispose(disposing: true);
GC.SuppressFinalize(this);
}
}

View File

@@ -0,0 +1,9 @@
<?xml version="1.0" encoding="utf-8"?>
<configuration>
<packageSources>
<clear />
<add key="nuget.org"
value="https://api.nuget.org/v3/index.json" />
</packageSources>
<disabledPackageSources />
</configuration>

View File

@@ -0,0 +1,86 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Utils.cs
using System.Collections;
using System.Collections.Immutable;
using Microsoft.DotNet.Interactive;
using Microsoft.DotNet.Interactive.Commands;
using Microsoft.DotNet.Interactive.Connection;
using Microsoft.DotNet.Interactive.Events;
public static class ObservableExtensions
{
public static SubscribedList<T> ToSubscribedList<T>(this IObservable<T> source)
{
return new SubscribedList<T>(source);
}
}
public static class KernelExtensions
{
internal static void SetUpValueSharingIfSupported(this ProxyKernel proxyKernel)
{
var supportedCommands = proxyKernel.KernelInfo.SupportedKernelCommands;
if (supportedCommands.Any(d => d.Name == nameof(RequestValue)) &&
supportedCommands.Any(d => d.Name == nameof(SendValue)))
{
proxyKernel.UseValueSharing();
}
}
internal static async Task<KernelCommandResult> SendAndThrowOnCommandFailedAsync(
this Kernel kernel,
KernelCommand command,
CancellationToken cancellationToken)
{
var result = await kernel.SendAsync(command, cancellationToken);
result.ThrowOnCommandFailed();
return result;
}
private static void ThrowOnCommandFailed(this KernelCommandResult result)
{
var failedEvents = result.Events.OfType<CommandFailed>();
if (!failedEvents.Any())
{
return;
}
if (failedEvents.Skip(1).Any())
{
var innerExceptions = failedEvents.Select(f => f.GetException());
throw new AggregateException(innerExceptions);
}
else
{
throw failedEvents.Single().GetException();
}
}
private static Exception GetException(this CommandFailed commandFailedEvent)
=> new Exception(commandFailedEvent.Message);
}
public class SubscribedList<T> : IReadOnlyList<T>, IDisposable
{
private ImmutableArray<T> _list = ImmutableArray<T>.Empty;
private readonly IDisposable _subscription;
public SubscribedList(IObservable<T> source)
{
_subscription = source.Subscribe(x => _list = _list.Add(x));
}
public IEnumerator<T> GetEnumerator()
{
return ((IEnumerable<T>)_list).GetEnumerator();
}
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
public int Count => _list.Length;
public T this[int index] => _list[index];
public void Dispose() => _subscription.Dispose();
}

View File

@@ -0,0 +1,12 @@
{
"version": 1,
"isRoot": true,
"tools": {
"Microsoft.dotnet-interactive": {
"version": "1.0.431302",
"commands": [
"dotnet-interactive"
]
}
}
}

View File

@@ -0,0 +1,23 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<RootNamespace>AutoGen.LMStudio</RootNamespace>
</PropertyGroup>
<Import Project="$(RepoRoot)/dotnet/nuget/nuget-package.props" />
<PropertyGroup>
<!-- NuGet Package Settings -->
<Title>AutoGen.LMStudio</Title>
<Description>
Provide support for consuming LMStudio openai-like API service in AutoGen
</Description>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\AutoGen.Core\AutoGen.Core.csproj" />
<ProjectReference Include="..\AutoGen.OpenAI\AutoGen.OpenAI.csproj" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,4 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GlobalUsing.cs
global using AutoGen.Core;

View File

@@ -0,0 +1,80 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// LMStudioAgent.cs
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using AutoGen.OpenAI;
using Azure.AI.OpenAI;
using Azure.Core.Pipeline;
using Azure.Core;
namespace AutoGen.LMStudio;
/// <summary>
/// agent that consumes local server from LM Studio
/// </summary>
/// <example>
/// [!code-csharp[LMStudioAgent](../../sample/AutoGen.BasicSamples/Example08_LMStudio.cs?name=lmstudio_example_1)]
/// </example>
public class LMStudioAgent : IAgent
{
private readonly GPTAgent innerAgent;
public LMStudioAgent(
string name,
LMStudioConfig config,
string systemMessage = "You are a helpful AI assistant",
float temperature = 0.7f,
int maxTokens = 1024,
IEnumerable<FunctionDefinition>? functions = null,
IDictionary<string, Func<string, Task<string>>>? functionMap = null)
{
var client = ConfigOpenAIClientForLMStudio(config);
innerAgent = new GPTAgent(
name: name,
systemMessage: systemMessage,
openAIClient: client,
modelName: "llm", // model name doesn't matter for LM Studio
temperature: temperature,
maxTokens: maxTokens,
functions: functions,
functionMap: functionMap);
}
public string Name => innerAgent.Name;
public Task<IMessage> GenerateReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
System.Threading.CancellationToken cancellationToken = default)
{
return innerAgent.GenerateReplyAsync(messages, options, cancellationToken);
}
private OpenAIClient ConfigOpenAIClientForLMStudio(LMStudioConfig config)
{
// create uri from host and port
var uri = config.Uri;
var accessToken = new AccessToken(string.Empty, DateTimeOffset.Now.AddDays(180));
var tokenCredential = DelegatedTokenCredential.Create((_, _) => accessToken);
var openAIClient = new OpenAIClient(uri, tokenCredential);
// remove authenication header from pipeline
var pipeline = HttpPipelineBuilder.Build(
new OpenAIClientOptions(OpenAIClientOptions.ServiceVersion.V2022_12_01),
Array.Empty<HttpPipelinePolicy>(),
[],
new ResponseClassifier());
// use reflection to override _pipeline field
var field = typeof(OpenAIClient).GetField("_pipeline", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
field.SetValue(openAIClient, pipeline);
// use reflection to set _isConfiguredForAzureOpenAI to false
var isConfiguredForAzureOpenAIField = typeof(OpenAIClient).GetField("_isConfiguredForAzureOpenAI", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
isConfiguredForAzureOpenAIField.SetValue(openAIClient, false);
return openAIClient;
}
}

View File

@@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// LMStudioConfig.cs
using System;
/// <summary>
/// Add support for consuming openai-like API from LM Studio
/// </summary>
public class LMStudioConfig : ILLMConfig
{
public LMStudioConfig(string host, int port, int version = 1)
{
this.Host = host;
this.Port = port;
this.Version = version;
}
public string Host { get; }
public int Port { get; }
public int Version { get; }
public Uri Uri => new Uri($"http://{Host}:{Port}/v{Version}");
}

View File

@@ -0,0 +1,31 @@
## AutoGen.LMStudio
This package provides support for consuming openai-like API from LMStudio local server.
## Installation
To use `AutoGen.LMStudio`, add the following package to your `.csproj` file:
```xml
<ItemGroup>
<PackageReference Include="AutoGen.LMStudio" Version="AUTOGEN_VERSION" />
</ItemGroup>
```
## Usage
```csharp
using AutoGen.LMStudio;
var localServerEndpoint = "localhost";
var port = 5000;
var lmStudioConfig = new LMStudioConfig(localServerEndpoint, port);
var agent = new LMStudioAgent(
name: "agent",
systemMessage: "You are an agent that help user to do some tasks.",
lmStudioConfig: lmStudioConfig)
.RegisterPrintMessage(); // register a hook to print message nicely to console
await agent.SendAsync("Can you write a piece of C# code to calculate 100th of fibonacci?");
```
## Update history
### Update on 0.0.7 (2024-02-11)
- Add `LMStudioAgent` to support consuming openai-like API from LMStudio local server.

View File

@@ -0,0 +1,133 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// MistralClientAgent.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.Core;
using AutoGen.Mistral.Extension;
namespace AutoGen.Mistral;
/// <summary>
/// Mistral client agent.
///
/// <para>This agent supports the following input message types:</para>
/// <list type="bullet">
/// <para><see cref="MessageEnvelope{T}"/> where T is <see cref="ChatMessage"/></para>
/// </list>
///
/// <para>This agent returns the following message types:</para>
/// <list type="bullet">
/// <para><see cref="MessageEnvelope{T}"/> where T is <see cref="ChatCompletionResponse"/></para>
/// </list>
///
/// You can register this agent with <see cref="MistralAgentExtension.RegisterMessageConnector(AutoGen.Mistral.MistralClientAgent, AutoGen.Mistral.MistralChatMessageConnector?)"/>
/// to support more AutoGen message types.
/// </summary>
public class MistralClientAgent : IStreamingAgent
{
private readonly MistralClient _client;
private readonly string _systemMessage;
private readonly string _model;
private readonly int? _randomSeed;
private readonly bool _jsonOutput = false;
private ToolChoiceEnum? _toolChoice;
/// <summary>
/// Create a new instance of <see cref="MistralClientAgent"/>.
/// </summary>
/// <param name="client"><see cref="MistralClient"/></param>
/// <param name="name">the name of this agent</param>
/// <param name="model">the mistral model id.</param>
/// <param name="systemMessage">system message.</param>
/// <param name="randomSeed">the seed to generate output.</param>
/// <param name="toolChoice">tool choice strategy.</param>
/// <param name="jsonOutput">use json output.</param>
public MistralClientAgent(
MistralClient client,
string name,
string model,
string systemMessage = "You are a helpful AI assistant",
int? randomSeed = null,
ToolChoiceEnum? toolChoice = null,
bool jsonOutput = false)
{
_client = client;
Name = name;
_systemMessage = systemMessage;
_model = model;
_randomSeed = randomSeed;
_jsonOutput = jsonOutput;
_toolChoice = toolChoice;
}
public string Name { get; }
public async Task<IMessage> GenerateReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
var request = BuildChatRequest(messages, options);
var response = await _client.CreateChatCompletionsAsync(request);
return new MessageEnvelope<ChatCompletionResponse>(response, from: this.Name);
}
public async Task<IAsyncEnumerable<IStreamingMessage>> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
var request = BuildChatRequest(messages, options);
var response = _client.StreamingChatCompletionsAsync(request);
return ProcessMessage(response);
}
private async IAsyncEnumerable<IMessage> ProcessMessage(IAsyncEnumerable<ChatCompletionResponse> response)
{
await foreach (var content in response)
{
yield return new MessageEnvelope<ChatCompletionResponse>(content, from: this.Name);
}
}
private ChatCompletionRequest BuildChatRequest(IEnumerable<IMessage> messages, GenerateReplyOptions? options)
{
var chatHistory = BuildChatHistory(messages);
var chatRequest = new ChatCompletionRequest(model: _model, messages: chatHistory.ToList(), temperature: options?.Temperature, randomSeed: _randomSeed)
{
MaxTokens = options?.MaxToken,
ResponseFormat = _jsonOutput ? new ResponseFormat() { ResponseFormatType = "json_object" } : null,
};
if (options?.Functions != null)
{
chatRequest.Tools = options.Functions.Select(f => new FunctionTool(f.ToMistralFunctionDefinition())).ToList();
chatRequest.ToolChoice = _toolChoice ?? ToolChoiceEnum.Auto;
}
return chatRequest;
}
private IEnumerable<ChatMessage> BuildChatHistory(IEnumerable<IMessage> messages)
{
var history = messages.Select(m => m switch
{
IMessage<ChatMessage> chatMessage => chatMessage.Content,
_ => throw new ArgumentException("Invalid message type")
});
// if there's no system message in the history, add one to the beginning
if (!history.Any(m => m.Role == ChatMessage.RoleEnum.System))
{
history = new[] { new ChatMessage(ChatMessage.RoleEnum.System, _systemMessage) }.Concat(history);
}
return history;
}
}

View File

@@ -0,0 +1,27 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<RootNamespace>AutoGen.Mistral</RootNamespace>
</PropertyGroup>
<Import Project="$(RepoRoot)/dotnet/nuget/nuget-package.props" />
<PropertyGroup>
<!-- NuGet Package Settings -->
<Title>AutoGen.Mistral</Title>
<Description>
Provide support for consuming Mistral model in AutoGen
</Description>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="System.Memory.Data" Version="8.0.0" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\AutoGen.Core\AutoGen.Core.csproj" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,43 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// JsonPropertyNameEnumConverter.cs
using System;
using System.Reflection;
using System.Text.Json;
using System.Text.Json.Serialization;
namespace AutoGen.Mistral;
internal class JsonPropertyNameEnumConverter<T> : JsonConverter<T> where T : struct, Enum
{
public override T Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
string value = reader.GetString() ?? throw new JsonException("Value was null.");
foreach (var field in typeToConvert.GetFields())
{
var attribute = field.GetCustomAttribute<JsonPropertyNameAttribute>();
if (attribute?.Name == value)
{
return (T)Enum.Parse(typeToConvert, field.Name);
}
}
throw new JsonException($"Unable to convert \"{value}\" to enum {typeToConvert}.");
}
public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options)
{
var field = value.GetType().GetField(value.ToString());
var attribute = field.GetCustomAttribute<JsonPropertyNameAttribute>();
if (attribute != null)
{
writer.WriteStringValue(attribute.Name);
}
else
{
writer.WriteStringValue(value.ToString());
}
}
}

View File

@@ -0,0 +1,116 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ChatCompletionRequest.cs
using System;
using System.Collections.Generic;
using System.Text.Json.Serialization;
namespace AutoGen.Mistral;
public class ChatCompletionRequest
{
/// <summary>
/// Initializes a new instance of the <see cref="ChatCompletionRequest" /> class.
/// </summary>
/// <param name="model">ID of the model to use. You can use the [List Available Models](/api#operation/listModels) API to see all of your available models, or see our [Model overview](/models) for model descriptions. (required).</param>
/// <param name="messages">The prompt(s) to generate completions for, encoded as a list of dict with role and content. The first prompt role should be &#x60;user&#x60; or &#x60;system&#x60;. (required).</param>
/// <param name="temperature">What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or &#x60;top_p&#x60; but not both. (default to 0.7M).</param>
/// <param name="topP">Nucleus sampling, where the model considers the results of the tokens with &#x60;top_p&#x60; probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or &#x60;temperature&#x60; but not both. (default to 1M).</param>
/// <param name="maxTokens">The maximum number of tokens to generate in the completion. The token count of your prompt plus &#x60;max_tokens&#x60; cannot exceed the model&#39;s context length. .</param>
/// <param name="stream">Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message. Otherwise, the server will hold the request open until the timeout or until completion, with the response containing the full result as JSON. (default to false).</param>
/// <param name="safePrompt">Whether to inject a safety prompt before all conversations. (default to false).</param>
/// <param name="randomSeed">The seed to use for random sampling. If set, different calls will generate deterministic results. .</param>
public ChatCompletionRequest(string? model = default(string), List<ChatMessage>? messages = default(List<ChatMessage>), float? temperature = 0.7f, float? topP = 1f, int? maxTokens = default(int?), bool? stream = false, bool safePrompt = false, int? randomSeed = default(int?))
{
// to ensure "model" is required (not null)
if (model == null)
{
throw new ArgumentNullException("model is a required property for ChatCompletionRequest and cannot be null");
}
this.Model = model;
// to ensure "messages" is required (not null)
if (messages == null)
{
throw new ArgumentNullException("messages is a required property for ChatCompletionRequest and cannot be null");
}
this.Messages = messages;
// use default value if no "temperature" provided
this.Temperature = temperature ?? 0.7f;
// use default value if no "topP" provided
this.TopP = topP ?? 1f;
this.MaxTokens = maxTokens;
// use default value if no "stream" provided
this.Stream = stream ?? false;
this.SafePrompt = safePrompt;
this.RandomSeed = randomSeed;
}
/// <summary>
/// ID of the model to use. You can use the [List Available Models](/api#operation/listModels) API to see all of your available models, or see our [Model overview](/models) for model descriptions.
/// </summary>
/// <value>ID of the model to use. You can use the [List Available Models](/api#operation/listModels) API to see all of your available models, or see our [Model overview](/models) for model descriptions. </value>
/// <example>mistral-tiny</example>
[JsonPropertyName("model")]
public string Model { get; set; }
/// <summary>
/// The prompt(s) to generate completions for, encoded as a list of dict with role and content. The first prompt role should be &#x60;user&#x60; or &#x60;system&#x60;.
/// </summary>
/// <value>The prompt(s) to generate completions for, encoded as a list of dict with role and content. The first prompt role should be &#x60;user&#x60; or &#x60;system&#x60;. </value>
/// <example>[{&quot;role&quot;:&quot;user&quot;,&quot;content&quot;:&quot;What is the best French cheese?&quot;}]</example>
[JsonPropertyName("messages")]
public List<ChatMessage> Messages { get; set; }
/// <summary>
/// What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or &#x60;top_p&#x60; but not both.
/// </summary>
/// <value>What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or &#x60;top_p&#x60; but not both. </value>
/// <example>0.7</example>
[JsonPropertyName("temperature")]
public float? Temperature { get; set; }
/// <summary>
/// Nucleus sampling, where the model considers the results of the tokens with &#x60;top_p&#x60; probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or &#x60;temperature&#x60; but not both.
/// </summary>
/// <value>Nucleus sampling, where the model considers the results of the tokens with &#x60;top_p&#x60; probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or &#x60;temperature&#x60; but not both. </value>
/// <example>1</example>
[JsonPropertyName("top_p")]
public float? TopP { get; set; }
/// <summary>
/// The maximum number of tokens to generate in the completion. The token count of your prompt plus &#x60;max_tokens&#x60; cannot exceed the model&#39;s context length.
/// </summary>
/// <value>The maximum number of tokens to generate in the completion. The token count of your prompt plus &#x60;max_tokens&#x60; cannot exceed the model&#39;s context length. </value>
/// <example>16</example>
[JsonPropertyName("max_tokens")]
public int? MaxTokens { get; set; }
/// <summary>
/// Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message. Otherwise, the server will hold the request open until the timeout or until completion, with the response containing the full result as JSON.
/// </summary>
/// <value>Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message. Otherwise, the server will hold the request open until the timeout or until completion, with the response containing the full result as JSON. </value>
[JsonPropertyName("stream")]
public bool? Stream { get; set; }
/// <summary>
/// Whether to inject a safety prompt before all conversations.
/// </summary>
/// <value>Whether to inject a safety prompt before all conversations. </value>
[JsonPropertyName("safe_prompt")]
public bool SafePrompt { get; set; }
/// <summary>
/// The seed to use for random sampling. If set, different calls will generate deterministic results.
/// </summary>
/// <value>The seed to use for random sampling. If set, different calls will generate deterministic results. </value>
[JsonPropertyName("random_seed")]
public int? RandomSeed { get; set; }
[JsonPropertyName("tools")]
public List<FunctionTool>? Tools { get; set; }
[JsonPropertyName("tool_choice")]
public ToolChoiceEnum? ToolChoice { get; set; }
[JsonPropertyName("response_format")]
public ResponseFormat? ResponseFormat { get; set; } = null;
}

View File

@@ -0,0 +1,47 @@
using System.Collections.Generic;
using System.Text.Json.Serialization;
namespace AutoGen.Mistral;
public class ChatCompletionResponse
{
/// <summary>
/// Gets or Sets Id
/// </summary>
/// <example>cmpl-e5cc70bb28c444948073e77776eb30ef</example>
[JsonPropertyName("id")]
public string? Id { get; set; }
/// <summary>
/// Gets or Sets VarObject
/// </summary>
/// <example>chat.completion</example>
[JsonPropertyName("object")]
public string? VarObject { get; set; }
/// <summary>
/// Gets or Sets Created
/// </summary>
/// <example>1702256327</example>
[JsonPropertyName("created")]
public int Created { get; set; }
/// <summary>
/// Gets or Sets Model
/// </summary>
/// <example>mistral-tiny</example>
[JsonPropertyName("model")]
public string? Model { get; set; }
/// <summary>
/// Gets or Sets Choices
/// </summary>
[JsonPropertyName("choices")]
public List<Choice>? Choices { get; set; }
/// <summary>
/// Gets or Sets Usage
/// </summary>
[JsonPropertyName("usage")]
public Usage? Usage { get; set; }
}

View File

@@ -0,0 +1,96 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ChatMessage.cs
using System.Collections.Generic;
using System.Text.Json.Serialization;
namespace AutoGen.Mistral;
public class ChatMessage
{
/// <summary>
/// Initializes a new instance of the <see cref="ChatMessage" /> class.
/// </summary>
/// <param name="role">role.</param>
/// <param name="content">content.</param>
public ChatMessage(RoleEnum? role = default(RoleEnum?), string? content = null)
{
this.Role = role;
this.Content = content;
}
[JsonConverter(typeof(JsonPropertyNameEnumConverter<RoleEnum>))]
public enum RoleEnum
{
/// <summary>
/// Enum System for value: system
/// </summary>
[JsonPropertyName("system")]
//[EnumMember(Value = "system")]
System = 1,
/// <summary>
/// Enum User for value: user
/// </summary>
[JsonPropertyName("user")]
//[EnumMember(Value = "user")]
User = 2,
/// <summary>
/// Enum Assistant for value: assistant
/// </summary>
[JsonPropertyName("assistant")]
//[EnumMember(Value = "assistant")]
Assistant = 3,
[JsonPropertyName("tool")]
Tool = 4,
}
/// <summary>
/// Gets or Sets Role
/// </summary>
[JsonPropertyName("role")]
public RoleEnum? Role { get; set; }
/// <summary>
/// Gets or Sets Content
/// </summary>
[JsonPropertyName("content")]
public string? Content { get; set; }
/// <summary>
/// Gets or Sets name for tool calls
/// </summary>
[JsonPropertyName("name")]
public string? Name { get; set; }
[JsonPropertyName("tool_calls")]
public List<FunctionContent>? ToolCalls { get; set; }
}
public class FunctionContent
{
public FunctionContent(FunctionCall function)
{
this.Function = function;
}
[JsonPropertyName("function")]
public FunctionCall Function { get; set; }
public class FunctionCall
{
public FunctionCall(string name, string arguments)
{
this.Name = name;
this.Arguments = arguments;
}
[JsonPropertyName("name")]
public string Name { get; set; }
[JsonPropertyName("arguments")]
public string Arguments { get; set; }
}
}

View File

@@ -0,0 +1,58 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Choice.cs
using System.Text.Json.Serialization;
namespace AutoGen.Mistral;
public class Choice
{
[JsonConverter(typeof(JsonPropertyNameEnumConverter<FinishReasonEnum>))]
public enum FinishReasonEnum
{
/// <summary>
/// Enum Stop for value: stop
/// </summary>
[JsonPropertyName("stop")]
Stop = 1,
/// <summary>
/// Enum Length for value: length
/// </summary>
[JsonPropertyName("length")]
Length = 2,
/// <summary>
/// Enum ModelLength for value: model_length
/// </summary>
[JsonPropertyName("model_length")]
ModelLength = 3,
[JsonPropertyName("error")]
Error = 4,
[JsonPropertyName("tool_calls")]
ToolCalls = 5,
}
/// <summary>
/// Gets or Sets FinishReason
/// </summary>
[JsonPropertyName("finish_reason")]
public FinishReasonEnum? FinishReason { get; set; }
[JsonPropertyName("index")]
public int Index { get; set; }
/// <summary>
/// Gets or Sets Message
/// </summary>
[JsonPropertyName("message")]
public ChatMessage? Message { get; set; }
/// <summary>
/// Gets or Sets Delta
/// </summary>
[JsonPropertyName("delta")]
public ChatMessage? Delta { get; set; }
}

View File

@@ -0,0 +1,36 @@
using System.Text.Json.Serialization;
namespace AutoGen.Mistral
{
public class Error
{
public Error(string type, string message, string? param = default(string), string? code = default(string))
{
Type = type;
Message = message;
Param = param;
Code = code;
}
[JsonPropertyName("type")]
public string Type { get; set; }
/// <summary>
/// Gets or Sets Message
/// </summary>
[JsonPropertyName("message")]
public string Message { get; set; }
/// <summary>
/// Gets or Sets Param
/// </summary>
[JsonPropertyName("param")]
public string? Param { get; set; }
/// <summary>
/// Gets or Sets Code
/// </summary>
[JsonPropertyName("code")]
public string? Code { get; set; }
}
}

View File

@@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ErrorResponse.cs
using System.Text.Json.Serialization;
namespace AutoGen.Mistral;
public class ErrorResponse
{
public ErrorResponse(Error error)
{
Error = error;
}
/// <summary>
/// Gets or Sets Error
/// </summary>
[JsonPropertyName("error")]
public Error Error { get; set; }
}

View File

@@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// FunctionDefinition.cs
using System.Text.Json.Serialization;
using Json.Schema;
namespace AutoGen.Mistral;
public class FunctionDefinition
{
public FunctionDefinition(string name, string description, JsonSchema? parameters = default)
{
Name = name;
Description = description;
Parameters = parameters;
}
[JsonPropertyName("name")]
public string Name { get; set; }
[JsonPropertyName("description")]
public string Description { get; set; }
[JsonPropertyName("parameters")]
public JsonSchema? Parameters { get; set; }
}

View File

@@ -0,0 +1,61 @@
using System;
using System.Text.Json.Serialization;
namespace AutoGen.Mistral;
public class Model
{
/// <summary>
/// Initializes a new instance of the <see cref="Model" /> class.
/// </summary>
/// <param name="id">id (required).</param>
/// <param name="varObject">varObject (required).</param>
/// <param name="created">created (required).</param>
/// <param name="ownedBy">ownedBy (required).</param>
public Model(string? id = default(string), string? varObject = default(string), int created = default(int), string? ownedBy = default(string))
{
// to ensure "id" is required (not null)
if (id == null)
{
throw new ArgumentNullException("id is a required property for Model and cannot be null");
}
this.Id = id;
// to ensure "varObject" is required (not null)
if (varObject == null)
{
throw new ArgumentNullException("varObject is a required property for Model and cannot be null");
}
this.VarObject = varObject;
this.Created = created;
// to ensure "ownedBy" is required (not null)
if (ownedBy == null)
{
throw new ArgumentNullException("ownedBy is a required property for Model and cannot be null");
}
this.OwnedBy = ownedBy;
}
/// <summary>
/// Gets or Sets Id
/// </summary>
[JsonPropertyName("id")]
public string Id { get; set; }
/// <summary>
/// Gets or Sets VarObject
/// </summary>
[JsonPropertyName("object")]
public string VarObject { get; set; }
/// <summary>
/// Gets or Sets Created
/// </summary>
[JsonPropertyName("created")]
public int Created { get; set; }
/// <summary>
/// Gets or Sets OwnedBy
/// </summary>
[JsonPropertyName("owned_by")]
public string OwnedBy { get; set; }
}

View File

@@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ResponseFormat.cs
using System.Text.Json.Serialization;
namespace AutoGen.Mistral;
public class ResponseFormat
{
[JsonPropertyName("type")]
public string ResponseFormatType { get; set; } = "json_object";
}

View File

@@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Tool.cs
using System.Text.Json.Serialization;
namespace AutoGen.Mistral;
public abstract class ToolBase
{
[JsonPropertyName("type")]
public string Type { get; set; }
public ToolBase(string type)
{
Type = type;
}
}
public class FunctionTool : ToolBase
{
public FunctionTool(FunctionDefinition function)
: base("function")
{
Function = function;
}
[JsonPropertyName("function")]
public FunctionDefinition Function { get; set; }
}
[JsonConverter(typeof(JsonPropertyNameEnumConverter<ToolChoiceEnum>))]
public enum ToolChoiceEnum
{
/// <summary>
/// Auto-detect whether to call a function.
/// </summary>
[JsonPropertyName("auto")]
Auto = 0,
/// <summary>
/// Won't call a function.
/// </summary>
[JsonPropertyName("none")]
None,
/// <summary>
/// Force to call a function.
/// </summary>
[JsonPropertyName("any")]
Any,
}

View File

@@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Usage.cs
using System.Text.Json.Serialization;
namespace AutoGen.Mistral;
public class Usage
{
[JsonPropertyName("prompt_tokens")]
public int PromptTokens { get; set; }
/// <summary>
/// Gets or Sets CompletionTokens
/// </summary>
/// <example>93</example>
[JsonPropertyName("completion_tokens")]
public int CompletionTokens { get; set; }
/// <summary>
/// Gets or Sets TotalTokens
/// </summary>
/// <example>107</example>
[JsonPropertyName("total_tokens")]
public int TotalTokens { get; set; }
}

View File

@@ -0,0 +1,59 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// FunctionContractExtension.cs
using System;
using System.Collections.Generic;
using AutoGen.Core;
using Json.Schema;
using Json.Schema.Generation;
namespace AutoGen.Mistral.Extension;
public static class FunctionContractExtension
{
/// <summary>
/// Convert a <see cref="FunctionContract"/> to a <see cref="FunctionDefinition"/> that can be used in funciton call.
/// </summary>
/// <param name="functionContract">function contract</param>
/// <returns><see cref="FunctionDefinition"/></returns>
public static FunctionDefinition ToMistralFunctionDefinition(this FunctionContract functionContract)
{
var functionDefinition = new FunctionDefinition(functionContract.Name ?? throw new Exception("Function name cannot be null"), functionContract.Description ?? throw new Exception("Function description cannot be null"));
var requiredParameterNames = new List<string>();
var propertiesSchemas = new Dictionary<string, JsonSchema>();
var propertySchemaBuilder = new JsonSchemaBuilder().Type(SchemaValueType.Object);
foreach (var param in functionContract.Parameters ?? [])
{
if (param.Name is null)
{
throw new InvalidOperationException("Parameter name cannot be null");
}
var schemaBuilder = new JsonSchemaBuilder().FromType(param.ParameterType ?? throw new ArgumentNullException(nameof(param.ParameterType)));
if (param.Description != null)
{
schemaBuilder = schemaBuilder.Description(param.Description);
}
if (param.IsRequired)
{
requiredParameterNames.Add(param.Name);
}
var schema = schemaBuilder.Build();
propertiesSchemas[param.Name] = schema;
}
propertySchemaBuilder = propertySchemaBuilder.Properties(propertiesSchemas);
propertySchemaBuilder = propertySchemaBuilder.Required(requiredParameterNames);
var option = new System.Text.Json.JsonSerializerOptions()
{
PropertyNamingPolicy = System.Text.Json.JsonNamingPolicy.CamelCase
};
functionDefinition.Parameters = propertySchemaBuilder.Build();
return functionDefinition;
}
}

View File

@@ -0,0 +1,40 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// MistralAgentExtension.cs
using AutoGen.Core;
namespace AutoGen.Mistral.Extension;
public static class MistralAgentExtension
{
/// <summary>
/// Register a <see cref="MistralChatMessageConnector"/> to support more AutoGen message types.
/// </summary>
public static MiddlewareStreamingAgent<MistralClientAgent> RegisterMessageConnector(
this MistralClientAgent agent, MistralChatMessageConnector? connector = null)
{
if (connector == null)
{
connector = new MistralChatMessageConnector();
}
return agent.RegisterStreamingMiddleware(connector)
.RegisterMiddleware(connector);
}
/// <summary>
/// Register a <see cref="MistralChatMessageConnector"/> to support more AutoGen message types.
/// </summary>
public static MiddlewareStreamingAgent<MistralClientAgent> RegisterMessageConnector(
this MiddlewareStreamingAgent<MistralClientAgent> agent, MistralChatMessageConnector? connector = null)
{
if (connector == null)
{
connector = new MistralChatMessageConnector();
}
return agent.RegisterStreamingMiddleware(connector)
.RegisterMiddleware(connector);
}
}

View File

@@ -0,0 +1,324 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// MistralChatMessageConnector.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.Core;
namespace AutoGen.Mistral;
public class MistralChatMessageConnector : IStreamingMiddleware, IMiddleware
{
public string? Name => nameof(MistralChatMessageConnector);
public Task<IAsyncEnumerable<IStreamingMessage>> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default)
{
return Task.FromResult(StreamingInvoke(context, agent, cancellationToken));
}
private async IAsyncEnumerable<IStreamingMessage> StreamingInvoke(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var messages = context.Messages;
var chatMessages = ProcessMessage(messages, agent);
var chunks = new List<ChatCompletionResponse>();
await foreach (var reply in await agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken))
{
if (reply is IStreamingMessage<ChatCompletionResponse> chatMessage)
{
chunks.Add(chatMessage.Content);
var response = ProcessChatCompletionResponse(chatMessage, agent);
if (response is not null)
{
yield return response;
}
}
else
{
yield return reply;
}
}
// if chunks is not empty, then return the aggregate message as the last message
// this is to meet the requirement of streaming call api
// where the last message should be the same result of non-streaming call api
if (chunks.Count == 0)
{
yield break;
}
var lastResponse = chunks.Last() ?? throw new ArgumentNullException("chunks.Last()");
var finalResponse = chunks.First() ?? throw new ArgumentNullException("chunks.First()");
if (lastResponse.Choices!.First().FinishReason == Choice.FinishReasonEnum.ToolCalls)
{
// process as tool call message
foreach (var response in chunks)
{
if (finalResponse.Choices!.First().Message is null)
{
finalResponse.Choices!.First().Message = response.Choices!.First().Delta;
if (finalResponse.Choices!.First().Message!.ToolCalls is null)
{
finalResponse.Choices!.First().Message!.ToolCalls = new List<FunctionContent>();
}
}
if (response.Choices!.First().Delta!.ToolCalls is not null)
{
finalResponse.Choices!.First().Message!.ToolCalls!.AddRange(response.Choices!.First().Delta!.ToolCalls!);
}
finalResponse.Choices!.First().FinishReason = response.Choices!.First().FinishReason;
// the usage information will be included in the last message
if (response.Usage is not null)
{
finalResponse.Usage = response.Usage;
}
}
}
else
{
// process as plain text message
foreach (var response in chunks)
{
if (finalResponse.Choices!.First().Message is null)
{
finalResponse.Choices!.First().Message = response.Choices!.First().Delta;
}
finalResponse.Choices!.First().Message!.Content += response.Choices!.First().Delta!.Content;
finalResponse.Choices!.First().FinishReason = response.Choices!.First().FinishReason;
// the usage information will be included in the last message
if (response.Usage is not null)
{
finalResponse.Usage = response.Usage;
}
}
}
yield return PostProcessMessage(finalResponse, agent);
}
public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default)
{
var messages = context.Messages;
var chatMessages = ProcessMessage(messages, agent);
var response = await agent.GenerateReplyAsync(chatMessages, context.Options, cancellationToken);
if (response is IMessage<ChatCompletionResponse> chatMessage)
{
return PostProcessMessage(chatMessage.Content, agent);
}
else
{
return response;
}
}
private IEnumerable<IMessage> ProcessMessage(IEnumerable<IMessage> messages, IAgent agent)
{
return messages.SelectMany<IMessage, IMessage>(m =>
{
if (m is IMessage<ChatMessage> chatMessage)
{
return [MessageEnvelope.Create(chatMessage.Content, from: chatMessage.From)];
}
else
{
return m switch
{
TextMessage textMessage => ProcessTextMessage(textMessage, agent),
ToolCallMessage toolCallMessage when (toolCallMessage.From is null || toolCallMessage.From == agent.Name) => ProcessToolCallMessage(toolCallMessage, agent),
ToolCallResultMessage toolCallResultMessage => ProcessToolCallResultMessage(toolCallResultMessage, agent),
AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => ProcessFunctionCallMiddlewareMessage(aggregateMessage, agent), // message type support for functioncall middleware
_ => [m],
};
}
});
}
private IMessage PostProcessMessage(ChatCompletionResponse response, IAgent from)
{
if (response.Choices is null)
{
throw new ArgumentNullException("response.Choices");
}
if (response.Choices?.Count != 1)
{
throw new NotSupportedException("response.Choices.Count != 1");
}
var choice = response.Choices[0];
var finishReason = choice.FinishReason ?? throw new ArgumentNullException("choice.FinishReason");
if (finishReason == Choice.FinishReasonEnum.Stop || finishReason == Choice.FinishReasonEnum.Length)
{
return new TextMessage(Role.Assistant, choice.Message?.Content ?? throw new ArgumentNullException("choice.Message.Content"), from: from.Name);
}
else if (finishReason == Choice.FinishReasonEnum.ToolCalls)
{
var functionContents = choice.Message?.ToolCalls ?? throw new ArgumentNullException("choice.Message.ToolCalls");
var toolCalls = functionContents.Select(f => new ToolCall(f.Function.Name, f.Function.Arguments)).ToList();
return new ToolCallMessage(toolCalls, from: from.Name);
}
else
{
throw new NotSupportedException($"FinishReason {finishReason} is not supported");
}
}
private IStreamingMessage? ProcessChatCompletionResponse(IStreamingMessage<ChatCompletionResponse> message, IAgent agent)
{
var response = message.Content;
if (response.VarObject != "chat.completion.chunk")
{
throw new NotSupportedException($"VarObject {response.VarObject} is not supported");
}
if (response.Choices is null)
{
throw new ArgumentNullException("response.Choices");
}
if (response.Choices?.Count != 1)
{
throw new NotSupportedException("response.Choices.Count != 1");
}
var choice = response.Choices[0];
var delta = choice.Delta;
// process text message if delta.content is not null
if (delta?.Content is string content)
{
return new TextMessageUpdate(role: Role.Assistant, content, from: agent.Name);
}
else if (delta?.ToolCalls is var toolCalls && toolCalls is { Count: 1 })
{
var toolCall = toolCalls[0];
var functionContent = toolCall.Function;
return new ToolCallMessageUpdate(functionContent.Name, functionContent.Arguments, from: agent.Name);
}
else
{
return null;
}
}
private IEnumerable<IMessage<ChatMessage>> ProcessTextMessage(TextMessage textMessage, IAgent agent)
{
IEnumerable<ChatMessage> messages;
// check if textMessage is system message
if (textMessage.Role == Role.System)
{
messages = [new ChatMessage(ChatMessage.RoleEnum.System, textMessage.Content)];
}
else if (textMessage.From == agent.Name)
{
// if this message is from agent iteself, then its role should be assistant
messages = [new ChatMessage(ChatMessage.RoleEnum.Assistant, textMessage.Content)];
}
else if (textMessage.From is null)
{
// if from is null, then process the message based on the role
if (textMessage.Role == Role.User)
{
messages = [new ChatMessage(ChatMessage.RoleEnum.User, textMessage.Content)];
}
else if (textMessage.Role == Role.Assistant)
{
messages = [new ChatMessage(ChatMessage.RoleEnum.Assistant, textMessage.Content)];
}
else
{
throw new NotSupportedException($"Role {textMessage.Role} is not supported");
}
}
else
{
// if from is not null, then the message is from user
messages = [new ChatMessage(ChatMessage.RoleEnum.User, textMessage.Content)];
}
return messages.Select(m => new MessageEnvelope<ChatMessage>(m, from: textMessage.From));
}
private IEnumerable<IMessage<ChatMessage>> ProcessToolCallResultMessage(ToolCallResultMessage toolCallResultMessage, IAgent agent)
{
var from = toolCallResultMessage.From;
var messages = new List<ChatMessage>();
foreach (var toolCall in toolCallResultMessage.ToolCalls)
{
if (toolCall.Result is null)
{
continue;
}
var message = new ChatMessage(ChatMessage.RoleEnum.Tool, content: toolCall.Result)
{
Name = toolCall.FunctionName,
};
messages.Add(message);
}
return messages.Select(m => new MessageEnvelope<ChatMessage>(m, from: toolCallResultMessage.From));
}
/// <summary>
/// Process the aggregate message from function call middleware. If the message is from another agent, this message will be interpreted as an ordinary plain <see cref="TextMessage"/>.
/// If the message is from the same agent or the from field is empty, this message will be expanded to the tool call message and tool call result message.
/// </summary>
/// <param name="aggregateMessage"></param>
/// <param name="agent"></param>
/// <returns></returns>
/// <exception cref="NotSupportedException"></exception>
private IEnumerable<IMessage<ChatMessage>> ProcessFunctionCallMiddlewareMessage(AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage, IAgent agent)
{
if (aggregateMessage.From is string from && from != agent.Name)
{
// if the message is from another agent, then interpret it as a plain text message
// where the content of the plain text message is the content of the tool call result message
var contents = aggregateMessage.Message2.ToolCalls.Select(t => t.Result);
var messages = contents.Select(c => new ChatMessage(ChatMessage.RoleEnum.Assistant, c));
return messages.Select(m => new MessageEnvelope<ChatMessage>(m, from: from));
}
// if the message is from the same agent or the from field is empty, then expand the message to tool call message and tool call result message
var toolCallMessage = aggregateMessage.Message1;
var toolCallResultMessage = aggregateMessage.Message2;
return this.ProcessToolCallMessage(toolCallMessage, agent).Concat(this.ProcessToolCallResultMessage(toolCallResultMessage, agent));
}
private IEnumerable<IMessage<ChatMessage>> ProcessToolCallMessage(ToolCallMessage toolCallMessage, IAgent agent)
{
IEnumerable<ChatMessage> messages;
// the scenario is not support when tool call message is from another agent
if (toolCallMessage.From is string from && from != agent.Name)
{
throw new NotSupportedException("Tool call message from another agent is not supported");
}
// convert tool call message to chat message
var chatMessage = new ChatMessage(ChatMessage.RoleEnum.Assistant);
chatMessage.ToolCalls = new List<FunctionContent>();
foreach (var toolCall in toolCallMessage.ToolCalls)
{
var functionCall = new FunctionContent.FunctionCall(toolCall.FunctionName, toolCall.FunctionArguments);
var functionContent = new FunctionContent(functionCall);
chatMessage.ToolCalls.Add(functionContent);
}
messages = [chatMessage];
return messages.Select(m => new MessageEnvelope<ChatMessage>(m, from: toolCallMessage.From));
}
}

View File

@@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// MistralAIModelID.cs
namespace AutoGen.Mistral;
public class MistralAIModelID
{
public const string OPEN_MISTRAL_7B = "open-mistral-7b";
public const string OPEN_MISTRAL_8X7B = "open-mixtral-8x7b";
public const string OPEN_MISTRAL_8X22B = "open-mixtral-8x22b";
public const string MISTRAL_SMALL_LATEST = "mistral-small-latest";
public const string MISTRAL_MEDIUM_LATEST = "mistral-medium-latest";
public const string MISTRAL_LARGE_LATEST = "mistral-large-latest";
}

View File

@@ -0,0 +1,168 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// MistralClient.cs
using System;
using System.Collections.Generic;
using System.IO;
using System.Net.Http;
using System.Security.Authentication;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Tasks;
namespace AutoGen.Mistral;
public class MistralClient : IDisposable
{
private readonly HttpClient _httpClient;
private readonly string baseUrl = "https://api.mistral.ai/v1";
public MistralClient(string apiKey, string? baseUrl = null)
{
_httpClient = new HttpClient();
_httpClient.DefaultRequestHeaders.Accept.Add(new System.Net.Http.Headers.MediaTypeWithQualityHeaderValue("application/json"));
_httpClient.DefaultRequestHeaders.Add("Authorization", $"Bearer {apiKey}");
this.baseUrl = baseUrl ?? this.baseUrl;
}
public MistralClient(HttpClient httpClient, string? baseUrl = null)
{
_httpClient = httpClient;
_httpClient.DefaultRequestHeaders.Accept.Add(new System.Net.Http.Headers.MediaTypeWithQualityHeaderValue("application/json"));
this.baseUrl = baseUrl ?? this.baseUrl;
}
public async Task<ChatCompletionResponse> CreateChatCompletionsAsync(ChatCompletionRequest chatCompletionRequest)
{
chatCompletionRequest.Stream = false;
var response = await HttpRequestRaw(HttpMethod.Post, chatCompletionRequest);
response.EnsureSuccessStatusCode();
var responseStream = await response.Content.ReadAsStreamAsync();
return await JsonSerializer.DeserializeAsync<ChatCompletionResponse>(responseStream) ?? throw new Exception("Failed to deserialize response");
}
public async IAsyncEnumerable<ChatCompletionResponse> StreamingChatCompletionsAsync(ChatCompletionRequest chatCompletionRequest)
{
chatCompletionRequest.Stream = true;
var response = await HttpRequestRaw(HttpMethod.Post, chatCompletionRequest, streaming: true);
using var stream = await response.Content.ReadAsStreamAsync();
using StreamReader reader = new StreamReader(stream);
string line;
SseEvent currentEvent = new SseEvent();
while ((line = await reader.ReadLineAsync()) != null)
{
if (!string.IsNullOrEmpty(line))
{
currentEvent.Data = line.Substring("data:".Length).Trim();
}
else // an empty line indicates the end of an event
{
if (currentEvent.Data == "[DONE]")
{
continue;
}
else if (currentEvent.EventType == null)
{
var res = await JsonSerializer.DeserializeAsync<ChatCompletionResponse>(
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data))) ?? throw new Exception("Failed to deserialize response");
yield return res;
}
else if (currentEvent.EventType != null)
{
var res = await JsonSerializer.DeserializeAsync<ErrorResponse>(
new MemoryStream(Encoding.UTF8.GetBytes(currentEvent.Data)));
throw new Exception(res?.Error.Message);
}
// Reset the current event for the next one
currentEvent = new SseEvent();
}
}
}
protected async Task<HttpResponseMessage> HttpRequestRaw(HttpMethod verb, object postData, bool streaming = false)
{
var url = $"{baseUrl}/chat/completions";
HttpResponseMessage response;
string resultAsString;
HttpRequestMessage req = new HttpRequestMessage(verb, url);
if (postData != null)
{
if (postData is HttpContent)
{
req.Content = postData as HttpContent;
}
else
{
string jsonContent = JsonSerializer.Serialize(postData,
new JsonSerializerOptions() { DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull });
var stringContent = new StringContent(jsonContent, Encoding.UTF8, "application/json");
req.Content = stringContent;
}
}
response = await this._httpClient.SendAsync(req,
streaming ? HttpCompletionOption.ResponseHeadersRead : HttpCompletionOption.ResponseContentRead);
if (response.IsSuccessStatusCode)
{
return response;
}
else
{
try
{
resultAsString = await response.Content.ReadAsStringAsync();
}
catch (Exception e)
{
resultAsString =
"Additionally, the following error was thrown when attempting to read the response content: " +
e.ToString();
}
if (response.StatusCode == System.Net.HttpStatusCode.Unauthorized)
{
throw new AuthenticationException(
"Mistral rejected your authorization, most likely due to an invalid API Key. Full API response follows: " +
resultAsString);
}
else if (response.StatusCode == System.Net.HttpStatusCode.InternalServerError)
{
throw new HttpRequestException(
"Mistral had an internal server error, which can happen occasionally. Please retry your request. " +
GetErrorMessage(resultAsString, response, url, url));
}
else
{
throw new HttpRequestException(GetErrorMessage(resultAsString, response, url, url));
}
}
}
private string GetErrorMessage(string resultAsString, HttpResponseMessage response, string name, string description = "")
{
return $"Error at {name} ({description}) with HTTP status code: {response.StatusCode}. Content: {resultAsString ?? "<no content>"}";
}
public void Dispose()
{
_httpClient.Dispose();
}
public class SseEvent
{
public SseEvent(string? eventType = null, string? data = null)
{
EventType = eventType;
Data = data;
}
public string? EventType { get; set; }
public string? Data { get; set; }
}
}

View File

@@ -0,0 +1,119 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GPTAgent.cs
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.OpenAI.Extension;
using Azure.AI.OpenAI;
namespace AutoGen.OpenAI;
/// <summary>
/// GPT agent that can be used to connect to OpenAI chat models like GPT-3.5, GPT-4, etc.
/// <para><see cref="GPTAgent" /> supports the following message types as input:</para>
/// <para>- <see cref="TextMessage"/></para>
/// <para>- <see cref="ImageMessage"/></para>
/// <para>- <see cref="MultiModalMessage"/></para>
/// <para>- <see cref="ToolCallMessage"/></para>
/// <para>- <see cref="ToolCallResultMessage"/></para>
/// <para>- <see cref="Message"/></para>
/// <para>- <see cref="IMessage{ChatRequestMessage}"/> where T is <see cref="ChatRequestMessage"/></para>
/// <para>- <see cref="AggregateMessage{TMessage1, TMessage2}"/> where TMessage1 is <see cref="ToolCallMessage"/> and TMessage2 is <see cref="ToolCallResultMessage"/></para>
///
/// <para><see cref="GPTAgent" /> returns the following message types:</para>
/// <para>- <see cref="TextMessage"/></para>
/// <para>- <see cref="ToolCallMessage"/></para>
/// <para>- <see cref="AggregateMessage{TMessage1, TMessage2}"/> where TMessage1 is <see cref="ToolCallMessage"/> and TMessage2 is <see cref="ToolCallResultMessage"/></para>
/// </summary>
public class GPTAgent : IStreamingAgent
{
private readonly IDictionary<string, Func<string, Task<string>>>? functionMap;
private readonly OpenAIClient openAIClient;
private readonly string? modelName;
private readonly OpenAIChatAgent _innerAgent;
public GPTAgent(
string name,
string systemMessage,
ILLMConfig config,
float temperature = 0.7f,
int maxTokens = 1024,
int? seed = null,
ChatCompletionsResponseFormat? responseFormat = null,
IEnumerable<FunctionDefinition>? functions = null,
IDictionary<string, Func<string, Task<string>>>? functionMap = null)
{
openAIClient = config switch
{
AzureOpenAIConfig azureConfig => new OpenAIClient(new Uri(azureConfig.Endpoint), new Azure.AzureKeyCredential(azureConfig.ApiKey)),
OpenAIConfig openAIConfig => new OpenAIClient(openAIConfig.ApiKey),
_ => throw new ArgumentException($"Unsupported config type {config.GetType()}"),
};
modelName = config switch
{
AzureOpenAIConfig azureConfig => azureConfig.DeploymentName,
OpenAIConfig openAIConfig => openAIConfig.ModelId,
_ => throw new ArgumentException($"Unsupported config type {config.GetType()}"),
};
_innerAgent = new OpenAIChatAgent(openAIClient, name, modelName, systemMessage, temperature, maxTokens, seed, responseFormat, functions);
Name = name;
this.functionMap = functionMap;
}
public GPTAgent(
string name,
string systemMessage,
OpenAIClient openAIClient,
string modelName,
float temperature = 0.7f,
int maxTokens = 1024,
int? seed = null,
ChatCompletionsResponseFormat? responseFormat = null,
IEnumerable<FunctionDefinition>? functions = null,
IDictionary<string, Func<string, Task<string>>>? functionMap = null)
{
this.openAIClient = openAIClient;
this.modelName = modelName;
Name = name;
this.functionMap = functionMap;
_innerAgent = new OpenAIChatAgent(openAIClient, name, modelName, systemMessage, temperature, maxTokens, seed, responseFormat, functions);
}
public string Name { get; }
public async Task<IMessage> GenerateReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
var agent = this._innerAgent
.RegisterMessageConnector();
if (this.functionMap is not null)
{
var functionMapMiddleware = new FunctionCallMiddleware(functionMap: this.functionMap);
agent = agent.RegisterMiddleware(functionMapMiddleware);
}
return await agent.GenerateReplyAsync(messages, options, cancellationToken);
}
public async Task<IAsyncEnumerable<IStreamingMessage>> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
var agent = this._innerAgent
.RegisterMessageConnector();
if (this.functionMap is not null)
{
var functionMapMiddleware = new FunctionCallMiddleware(functionMap: this.functionMap);
agent = agent.RegisterStreamingMiddleware(functionMapMiddleware);
}
return await agent.GenerateStreamingReplyAsync(messages, options, cancellationToken);
}
}

View File

@@ -0,0 +1,158 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// OpenAIChatAgent.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.OpenAI.Extension;
using Azure.AI.OpenAI;
namespace AutoGen.OpenAI;
/// <summary>
/// OpenAI client agent. This agent is a thin wrapper around <see cref="OpenAIClient"/> to provide a simple interface for chat completions.
/// To better work with other agents, it's recommended to use <see cref="GPTAgent"/> which supports more message types and have a better compatibility with other agents.
/// <para><see cref="OpenAIChatAgent" /> supports the following message types:</para>
/// <list type="bullet">
/// <item>
/// <see cref="MessageEnvelope{T}"/> where T is <see cref="ChatRequestMessage"/>: chat request message.
/// </item>
/// </list>
/// <para><see cref="OpenAIChatAgent" /> returns the following message types:</para>
/// <list type="bullet">
/// <item>
/// <see cref="MessageEnvelope{T}"/> where T is <see cref="ChatResponseMessage"/>: chat response message.
/// <see cref="MessageEnvelope{T}"/> where T is <see cref="StreamingChatCompletionsUpdate"/>: streaming chat completions update.
/// </item>
/// </list>
/// </summary>
public class OpenAIChatAgent : IStreamingAgent
{
private readonly OpenAIClient openAIClient;
private readonly string modelName;
private readonly float _temperature;
private readonly int _maxTokens = 1024;
private readonly IEnumerable<FunctionDefinition>? _functions;
private readonly string _systemMessage;
private readonly ChatCompletionsResponseFormat? _responseFormat;
private readonly int? _seed;
/// <summary>
/// Create a new instance of <see cref="OpenAIChatAgent"/>.
/// </summary>
/// <param name="openAIClient">openai client</param>
/// <param name="name">agent name</param>
/// <param name="modelName">model name. e.g. gpt-turbo-3.5</param>
/// <param name="systemMessage">system message</param>
/// <param name="temperature">temperature</param>
/// <param name="maxTokens">max tokens to generated</param>
/// <param name="responseFormat">response format, set it to <see cref="ChatCompletionsResponseFormat.JsonObject"/> to enable json mode.</param>
/// <param name="seed">seed to use, set it to enable deterministic output</param>
/// <param name="functions">functions</param>
public OpenAIChatAgent(
OpenAIClient openAIClient,
string name,
string modelName,
string systemMessage = "You are a helpful AI assistant",
float temperature = 0.7f,
int maxTokens = 1024,
int? seed = null,
ChatCompletionsResponseFormat? responseFormat = null,
IEnumerable<FunctionDefinition>? functions = null)
{
this.openAIClient = openAIClient;
this.modelName = modelName;
this.Name = name;
_temperature = temperature;
_maxTokens = maxTokens;
_functions = functions;
_systemMessage = systemMessage;
_responseFormat = responseFormat;
_seed = seed;
}
public string Name { get; }
public async Task<IMessage> GenerateReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
var settings = this.CreateChatCompletionsOptions(options, messages);
var reply = await this.openAIClient.GetChatCompletionsAsync(settings, cancellationToken);
return new MessageEnvelope<ChatResponseMessage>(reply.Value.Choices.First().Message, from: this.Name);
}
public Task<IAsyncEnumerable<IStreamingMessage>> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
return Task.FromResult(this.StreamingReplyAsync(messages, options, cancellationToken));
}
private async IAsyncEnumerable<IStreamingMessage> StreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var settings = this.CreateChatCompletionsOptions(options, messages);
var response = await this.openAIClient.GetChatCompletionsStreamingAsync(settings);
await foreach (var update in response.WithCancellation(cancellationToken))
{
if (update.ChoiceIndex > 0)
{
throw new InvalidOperationException("Only one choice is supported in streaming response");
}
yield return new MessageEnvelope<StreamingChatCompletionsUpdate>(update, from: this.Name);
}
}
private ChatCompletionsOptions CreateChatCompletionsOptions(GenerateReplyOptions? options, IEnumerable<IMessage> messages)
{
var oaiMessages = messages.Select(m => m switch
{
IMessage<ChatRequestMessage> chatRequestMessage => chatRequestMessage.Content,
_ => throw new ArgumentException("Invalid message type")
});
// add system message if there's no system message in messages
if (!oaiMessages.Any(m => m is ChatRequestSystemMessage))
{
oaiMessages = new[] { new ChatRequestSystemMessage(_systemMessage) }.Concat(oaiMessages);
}
var settings = new ChatCompletionsOptions(this.modelName, oaiMessages)
{
MaxTokens = options?.MaxToken ?? _maxTokens,
Temperature = options?.Temperature ?? _temperature,
ResponseFormat = _responseFormat,
Seed = _seed,
};
var openAIFunctionDefinitions = options?.Functions?.Select(f => f.ToOpenAIFunctionDefinition());
var functions = openAIFunctionDefinitions ?? _functions;
if (functions is not null && functions.Count() > 0)
{
foreach (var f in functions)
{
settings.Tools.Add(new ChatCompletionsFunctionToolDefinition(f));
}
}
if (options?.StopSequence is var sequence && sequence is { Length: > 0 })
{
foreach (var seq in sequence)
{
settings.StopSequences.Add(seq);
}
}
return settings;
}
}

View File

@@ -0,0 +1,25 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<RootNamespace>AutoGen.OpenAI</RootNamespace>
</PropertyGroup>
<Import Project="$(RepoRoot)/dotnet/nuget/nuget-package.props" />
<PropertyGroup>
<!-- NuGet Package Settings -->
<Title>AutoGen.OpenAI</Title>
<Description>
OpenAI Intergration for AutoGen.
</Description>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Azure.AI.OpenAI" Version="$(AzureOpenAIVersion)" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\AutoGen.Core\AutoGen.Core.csproj" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AzureOpenAIConfig.cs
namespace AutoGen.OpenAI;
public class AzureOpenAIConfig : ILLMConfig
{
public AzureOpenAIConfig(string endpoint, string deploymentName, string apiKey, string? modelId = null)
{
this.Endpoint = endpoint;
this.DeploymentName = deploymentName;
this.ApiKey = apiKey;
this.ModelId = modelId;
}
public string Endpoint { get; }
public string DeploymentName { get; }
public string ApiKey { get; }
public string? ModelId { get; }
}

View File

@@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// FunctionContractExtension.cs
using System;
using System.Collections.Generic;
using Azure.AI.OpenAI;
using Json.Schema;
using Json.Schema.Generation;
namespace AutoGen.OpenAI.Extension;
public static class FunctionContractExtension
{
/// <summary>
/// Convert a <see cref="FunctionContract"/> to a <see cref="FunctionDefinition"/> that can be used in gpt funciton call.
/// </summary>
/// <param name="functionContract">function contract</param>
/// <returns><see cref="FunctionDefinition"/></returns>
public static FunctionDefinition ToOpenAIFunctionDefinition(this FunctionContract functionContract)
{
var functionDefinition = new FunctionDefinition
{
Name = functionContract.Name,
Description = functionContract.Description,
};
var requiredParameterNames = new List<string>();
var propertiesSchemas = new Dictionary<string, JsonSchema>();
var propertySchemaBuilder = new JsonSchemaBuilder().Type(SchemaValueType.Object);
foreach (var param in functionContract.Parameters ?? [])
{
if (param.Name is null)
{
throw new InvalidOperationException("Parameter name cannot be null");
}
var schemaBuilder = new JsonSchemaBuilder().FromType(param.ParameterType ?? throw new ArgumentNullException(nameof(param.ParameterType)));
if (param.Description != null)
{
schemaBuilder = schemaBuilder.Description(param.Description);
}
if (param.IsRequired)
{
requiredParameterNames.Add(param.Name);
}
var schema = schemaBuilder.Build();
propertiesSchemas[param.Name] = schema;
}
propertySchemaBuilder = propertySchemaBuilder.Properties(propertiesSchemas);
propertySchemaBuilder = propertySchemaBuilder.Required(requiredParameterNames);
var option = new System.Text.Json.JsonSerializerOptions()
{
PropertyNamingPolicy = System.Text.Json.JsonNamingPolicy.CamelCase
};
functionDefinition.Parameters = BinaryData.FromObjectAsJson(propertySchemaBuilder.Build(), option);
return functionDefinition;
}
}

View File

@@ -0,0 +1,228 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// MessageExtension.cs
using System;
using System.Collections.Generic;
using System.Linq;
using Azure.AI.OpenAI;
namespace AutoGen.OpenAI;
public static class MessageExtension
{
public static string TEXT_CONTENT_TYPE = "text";
public static string IMAGE_CONTENT_TYPE = "image";
public static ChatRequestUserMessage ToChatRequestUserMessage(this Message message)
{
if (message.Value is ChatRequestUserMessage message1)
{
return message1;
}
else if (message?.Metadata is { Count: > 0 })
{
var itemList = new List<ChatMessageContentItem>();
foreach (var item in message.Metadata)
{
if (item.Key == TEXT_CONTENT_TYPE && item.Value is string txt)
{
itemList.Add(new ChatMessageTextContentItem(txt));
}
else if (item.Key == IMAGE_CONTENT_TYPE && item.Value is string url)
{
itemList.Add(new ChatMessageImageContentItem(new Uri(url)));
}
}
if (itemList.Count > 0)
{
return new ChatRequestUserMessage(itemList);
}
else
{
throw new ArgumentException("Content is null and metadata is null");
}
}
else if (!string.IsNullOrEmpty(message?.Content))
{
return new ChatRequestUserMessage(message!.Content);
}
throw new ArgumentException("Content is null and metadata is null");
}
public static IEnumerable<ChatRequestMessage> ToOpenAIChatRequestMessage(this IAgent agent, IMessage message)
{
if (message is IMessage<ChatRequestMessage> oaiMessage)
{
// short-circuit
return [oaiMessage.Content];
}
if (message.From != agent.Name)
{
if (message is TextMessage textMessage)
{
if (textMessage.Role == Role.System)
{
var msg = new ChatRequestSystemMessage(textMessage.Content);
return [msg];
}
else
{
var msg = new ChatRequestUserMessage(textMessage.Content);
return [msg];
}
}
else if (message is ImageMessage imageMessage)
{
// multi-modal
var msg = new ChatRequestUserMessage(new ChatMessageImageContentItem(new Uri(imageMessage.Url)));
return [msg];
}
else if (message is ToolCallMessage)
{
throw new ArgumentException($"ToolCallMessage is not supported when message.From is not the same with agent");
}
else if (message is ToolCallResultMessage toolCallResult)
{
return toolCallResult.ToolCalls.Select(m =>
{
var msg = new ChatRequestToolMessage(m.Result, m.FunctionName);
return msg;
});
}
else if (message is MultiModalMessage multiModalMessage)
{
var messageContent = multiModalMessage.Content.Select<IMessage, ChatMessageContentItem>(m =>
{
return m switch
{
TextMessage textMessage => new ChatMessageTextContentItem(textMessage.Content),
ImageMessage imageMessage => new ChatMessageImageContentItem(new Uri(imageMessage.Url)),
_ => throw new ArgumentException($"Unknown message type: {m.GetType()}")
};
});
var msg = new ChatRequestUserMessage(messageContent);
return [msg];
}
else if (message is AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage)
{
// convert as user message
var resultMessage = aggregateMessage.Message2;
return resultMessage.ToolCalls.Select(m => new ChatRequestUserMessage(m.Result));
}
else if (message is Message msg)
{
if (msg.Role == Role.System)
{
var systemMessage = new ChatRequestSystemMessage(msg.Content ?? string.Empty);
return [systemMessage];
}
else if (msg.FunctionName is null && msg.FunctionArguments is null)
{
var userMessage = msg.ToChatRequestUserMessage();
return [userMessage];
}
else if (msg.FunctionName is not null && msg.FunctionArguments is not null && msg.Content is not null)
{
if (msg.Role == Role.Function)
{
return [new ChatRequestFunctionMessage(msg.FunctionName, msg.Content)];
}
else
{
return [new ChatRequestUserMessage(msg.Content)];
}
}
else
{
var userMessage = new ChatRequestUserMessage(msg.Content ?? throw new ArgumentException("Content is null"));
return [userMessage];
}
}
else
{
throw new ArgumentException($"Unknown message type: {message.GetType()}");
}
}
else
{
if (message is TextMessage textMessage)
{
if (textMessage.Role == Role.System)
{
throw new ArgumentException("System message is not supported when message.From is the same with agent");
}
return [new ChatRequestAssistantMessage(textMessage.Content)];
}
else if (message is ToolCallMessage toolCallMessage)
{
var assistantMessage = new ChatRequestAssistantMessage(string.Empty);
var toolCalls = toolCallMessage.ToolCalls.Select(tc => new ChatCompletionsFunctionToolCall(tc.FunctionName, tc.FunctionName, tc.FunctionArguments));
foreach (var tc in toolCalls)
{
assistantMessage.ToolCalls.Add(tc);
}
return [assistantMessage];
}
else if (message is AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage)
{
var toolCallMessage1 = aggregateMessage.Message1;
var toolCallResultMessage = aggregateMessage.Message2;
var assistantMessage = new ChatRequestAssistantMessage(string.Empty);
var toolCalls = toolCallMessage1.ToolCalls.Select(tc => new ChatCompletionsFunctionToolCall(tc.FunctionName, tc.FunctionName, tc.FunctionArguments));
foreach (var tc in toolCalls)
{
assistantMessage.ToolCalls.Add(tc);
}
var toolCallResults = toolCallResultMessage.ToolCalls.Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName));
// return assistantMessage and tool call result messages
var messages = new List<ChatRequestMessage> { assistantMessage };
messages.AddRange(toolCallResults);
return messages;
}
else if (message is Message msg)
{
if (msg.FunctionArguments is not null && msg.FunctionName is not null && msg.Content is not null)
{
var assistantMessage = new ChatRequestAssistantMessage(msg.Content);
assistantMessage.FunctionCall = new FunctionCall(msg.FunctionName, msg.FunctionArguments);
var functionCallMessage = new ChatRequestFunctionMessage(msg.FunctionName, msg.Content);
return [assistantMessage, functionCallMessage];
}
else
{
if (msg.Role == Role.Function)
{
return [new ChatRequestFunctionMessage(msg.FunctionName!, msg.Content!)];
}
else
{
var assistantMessage = new ChatRequestAssistantMessage(msg.Content!);
if (msg.FunctionName is not null && msg.FunctionArguments is not null)
{
assistantMessage.FunctionCall = new FunctionCall(msg.FunctionName, msg.FunctionArguments);
}
return [assistantMessage];
}
}
}
else
{
throw new ArgumentException($"Unknown message type: {message.GetType()}");
}
}
}
}

View File

@@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// OpenAIAgentExtension.cs
namespace AutoGen.OpenAI.Extension;
public static class OpenAIAgentExtension
{
/// <summary>
/// Register an <see cref="OpenAIChatRequestMessageConnector"/> to the <see cref="OpenAIChatAgent"/>
/// </summary>
/// <param name="connector">the connector to use. If null, a new instance of <see cref="OpenAIChatRequestMessageConnector"/> will be created.</param>
public static MiddlewareStreamingAgent<OpenAIChatAgent> RegisterMessageConnector(
this OpenAIChatAgent agent, OpenAIChatRequestMessageConnector? connector = null)
{
if (connector == null)
{
connector = new OpenAIChatRequestMessageConnector();
}
return agent.RegisterStreamingMiddleware(connector);
}
/// <summary>
/// Register an <see cref="OpenAIChatRequestMessageConnector"/> to the <see cref="MiddlewareAgent{T}"/> where T is <see cref="OpenAIChatAgent"/>
/// </summary>
/// <param name="connector">the connector to use. If null, a new instance of <see cref="OpenAIChatRequestMessageConnector"/> will be created.</param>
public static MiddlewareStreamingAgent<OpenAIChatAgent> RegisterMessageConnector(
this MiddlewareStreamingAgent<OpenAIChatAgent> agent, OpenAIChatRequestMessageConnector? connector = null)
{
if (connector == null)
{
connector = new OpenAIChatRequestMessageConnector();
}
return agent.RegisterStreamingMiddleware(connector);
}
}

View File

@@ -0,0 +1,4 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GlobalUsing.cs
global using AutoGen.Core;

View File

@@ -0,0 +1,445 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// OpenAIChatRequestMessageConnector.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Azure.AI.OpenAI;
namespace AutoGen.OpenAI;
/// <summary>
/// This middleware converts the incoming <see cref="IMessage"/> to <see cref="IMessage{ChatRequestMessage}" /> where T is <see cref="ChatRequestMessage"/> before sending to agent. And converts the output <see cref="ChatResponseMessage"/> to <see cref="IMessage"/> after receiving from agent.
/// <para>Supported <see cref="IMessage"/> are</para>
/// <para>- <see cref="TextMessage"/></para>
/// <para>- <see cref="ImageMessage"/></para>
/// <para>- <see cref="MultiModalMessage"/></para>
/// <para>- <see cref="ToolCallMessage"/></para>
/// <para>- <see cref="ToolCallResultMessage"/></para>
/// <para>- <see cref="Message"/></para>
/// <para>- <see cref="IMessage{ChatRequestMessage}"/> where T is <see cref="ChatRequestMessage"/></para>
/// <para>- <see cref="AggregateMessage{TMessage1, TMessage2}"/> where TMessage1 is <see cref="ToolCallMessage"/> and TMessage2 is <see cref="ToolCallResultMessage"/></para>
/// </summary>
public class OpenAIChatRequestMessageConnector : IMiddleware, IStreamingMiddleware
{
private bool strictMode = false;
public OpenAIChatRequestMessageConnector(bool strictMode = false)
{
this.strictMode = strictMode;
}
public string? Name => nameof(OpenAIChatRequestMessageConnector);
public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default)
{
var chatMessages = ProcessIncomingMessages(agent, context.Messages)
.Select(m => new MessageEnvelope<ChatRequestMessage>(m));
var reply = await agent.GenerateReplyAsync(chatMessages, context.Options, cancellationToken);
return PostProcessMessage(reply);
}
public async Task<IAsyncEnumerable<IStreamingMessage>> InvokeAsync(
MiddlewareContext context,
IStreamingAgent agent,
CancellationToken cancellationToken = default)
{
return InvokeStreamingAsync(context, agent, cancellationToken);
}
private async IAsyncEnumerable<IStreamingMessage> InvokeStreamingAsync(
MiddlewareContext context,
IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken)
{
var chatMessages = ProcessIncomingMessages(agent, context.Messages)
.Select(m => new MessageEnvelope<ChatRequestMessage>(m));
var streamingReply = await agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken);
string? currentToolName = null;
await foreach (var reply in streamingReply)
{
if (reply is IStreamingMessage<StreamingChatCompletionsUpdate> update)
{
if (update.Content.FunctionName is string functionName)
{
currentToolName = functionName;
}
else if (update.Content.ToolCallUpdate is StreamingFunctionToolCallUpdate toolCallUpdate && toolCallUpdate.Name is string toolCallName)
{
currentToolName = toolCallName;
}
var postProcessMessage = PostProcessStreamingMessage(update, currentToolName);
if (postProcessMessage != null)
{
yield return postProcessMessage;
}
}
else
{
yield return reply;
}
}
}
public IMessage PostProcessMessage(IMessage message)
{
return message switch
{
TextMessage => message,
ImageMessage => message,
MultiModalMessage => message,
ToolCallMessage => message,
ToolCallResultMessage => message,
Message => message,
AggregateMessage<ToolCallMessage, ToolCallResultMessage> => message,
IMessage<ChatResponseMessage> m => PostProcessMessage(m),
_ => throw new InvalidOperationException("The type of message is not supported. Must be one of TextMessage, ImageMessage, MultiModalMessage, ToolCallMessage, ToolCallResultMessage, Message, IMessage<ChatRequestMessage>, AggregateMessage<ToolCallMessage, ToolCallResultMessage>"),
};
}
public IStreamingMessage? PostProcessStreamingMessage(IStreamingMessage<StreamingChatCompletionsUpdate> update, string? currentToolName)
{
if (update.Content.ContentUpdate is string contentUpdate)
{
// text message
return new TextMessageUpdate(Role.Assistant, contentUpdate, from: update.From);
}
else if (update.Content.FunctionName is string functionName)
{
return new ToolCallMessageUpdate(functionName, string.Empty, from: update.From);
}
else if (update.Content.FunctionArgumentsUpdate is string functionArgumentsUpdate && currentToolName is string)
{
return new ToolCallMessageUpdate(currentToolName, functionArgumentsUpdate, from: update.From);
}
else if (update.Content.ToolCallUpdate is StreamingFunctionToolCallUpdate tooCallUpdate && currentToolName is string)
{
return new ToolCallMessageUpdate(tooCallUpdate.Name ?? currentToolName, tooCallUpdate.ArgumentsUpdate, from: update.From);
}
else
{
return null;
}
}
private IMessage PostProcessMessage(IMessage<ChatResponseMessage> message)
{
var chatResponseMessage = message.Content;
if (chatResponseMessage.Content is string content)
{
return new TextMessage(Role.Assistant, content, message.From);
}
if (chatResponseMessage.FunctionCall is FunctionCall functionCall)
{
return new ToolCallMessage(functionCall.Name, functionCall.Arguments, message.From);
}
if (chatResponseMessage.ToolCalls.Where(tc => tc is ChatCompletionsFunctionToolCall).Any())
{
var functionToolCalls = chatResponseMessage.ToolCalls
.Where(tc => tc is ChatCompletionsFunctionToolCall)
.Select(tc => (ChatCompletionsFunctionToolCall)tc);
var toolCalls = functionToolCalls.Select(tc => new ToolCall(tc.Name, tc.Arguments));
return new ToolCallMessage(toolCalls, message.From);
}
throw new InvalidOperationException("Invalid ChatResponseMessage");
}
public IEnumerable<ChatRequestMessage> ProcessIncomingMessages(IAgent agent, IEnumerable<IMessage> messages)
{
return messages.SelectMany(m =>
{
if (m.From == null)
{
return ProcessIncomingMessagesWithEmptyFrom(m);
}
else if (m.From == agent.Name)
{
return ProcessIncomingMessagesForSelf(m);
}
else
{
return ProcessIncomingMessagesForOther(m);
}
});
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForSelf(IMessage message)
{
return message switch
{
TextMessage textMessage => ProcessIncomingMessagesForSelf(textMessage),
ImageMessage imageMessage => ProcessIncomingMessagesForSelf(imageMessage),
MultiModalMessage multiModalMessage => ProcessIncomingMessagesForSelf(multiModalMessage),
ToolCallMessage toolCallMessage => ProcessIncomingMessagesForSelf(toolCallMessage),
ToolCallResultMessage toolCallResultMessage => ProcessIncomingMessagesForSelf(toolCallResultMessage),
Message msg => ProcessIncomingMessagesForSelf(msg),
IMessage<ChatRequestMessage> crm => ProcessIncomingMessagesForSelf(crm),
AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => ProcessIncomingMessagesForSelf(aggregateMessage),
_ => throw new NotImplementedException(),
};
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesWithEmptyFrom(IMessage message)
{
return message switch
{
TextMessage textMessage => ProcessIncomingMessagesWithEmptyFrom(textMessage),
ImageMessage imageMessage => ProcessIncomingMessagesWithEmptyFrom(imageMessage),
MultiModalMessage multiModalMessage => ProcessIncomingMessagesWithEmptyFrom(multiModalMessage),
ToolCallMessage toolCallMessage => ProcessIncomingMessagesWithEmptyFrom(toolCallMessage),
ToolCallResultMessage toolCallResultMessage => ProcessIncomingMessagesWithEmptyFrom(toolCallResultMessage),
Message msg => ProcessIncomingMessagesWithEmptyFrom(msg),
IMessage<ChatRequestMessage> crm => ProcessIncomingMessagesWithEmptyFrom(crm),
AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => ProcessIncomingMessagesWithEmptyFrom(aggregateMessage),
_ => throw new NotImplementedException(),
};
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForOther(IMessage message)
{
return message switch
{
TextMessage textMessage => ProcessIncomingMessagesForOther(textMessage),
ImageMessage imageMessage => ProcessIncomingMessagesForOther(imageMessage),
MultiModalMessage multiModalMessage => ProcessIncomingMessagesForOther(multiModalMessage),
ToolCallMessage toolCallMessage => ProcessIncomingMessagesForOther(toolCallMessage),
ToolCallResultMessage toolCallResultMessage => ProcessIncomingMessagesForOther(toolCallResultMessage),
Message msg => ProcessIncomingMessagesForOther(msg),
IMessage<ChatRequestMessage> crm => ProcessIncomingMessagesForOther(crm),
AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage => ProcessIncomingMessagesForOther(aggregateMessage),
_ => throw new NotImplementedException(),
};
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForSelf(TextMessage message)
{
if (message.Role == Role.System)
{
return new[] { new ChatRequestSystemMessage(message.Content) };
}
else
{
return new[] { new ChatRequestAssistantMessage(message.Content) };
}
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForSelf(ImageMessage _)
{
return [new ChatRequestAssistantMessage("// Image Message is not supported")];
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForSelf(MultiModalMessage _)
{
return [new ChatRequestAssistantMessage("// MultiModal Message is not supported")];
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForSelf(ToolCallMessage message)
{
var toolCall = message.ToolCalls.Select(tc => new ChatCompletionsFunctionToolCall(tc.FunctionName, tc.FunctionName, tc.FunctionArguments));
var chatRequestMessage = new ChatRequestAssistantMessage(string.Empty);
foreach (var tc in toolCall)
{
chatRequestMessage.ToolCalls.Add(tc);
}
return new[] { chatRequestMessage };
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForSelf(ToolCallResultMessage message)
{
return message.ToolCalls.Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName));
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForSelf(Message message)
{
if (message.Role == Role.System)
{
return new[] { new ChatRequestSystemMessage(message.Content) };
}
else if (message.Content is string content && content is { Length: > 0 })
{
if (message.FunctionName is null)
{
return new[] { new ChatRequestAssistantMessage(message.Content) };
}
else
{
return new[] { new ChatRequestToolMessage(content, message.FunctionName) };
}
}
else if (message.FunctionName is string functionName)
{
var msg = new ChatRequestAssistantMessage(content: null)
{
FunctionCall = new FunctionCall(functionName, message.FunctionArguments)
};
return new[]
{
msg,
};
}
else
{
throw new InvalidOperationException("Invalid Message as message from self.");
}
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForSelf(IMessage<ChatRequestMessage> message)
{
return new[] { message.Content };
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForSelf(AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage)
{
var toolCallMessage1 = aggregateMessage.Message1;
var toolCallResultMessage = aggregateMessage.Message2;
var assistantMessage = new ChatRequestAssistantMessage(string.Empty);
var toolCalls = toolCallMessage1.ToolCalls.Select(tc => new ChatCompletionsFunctionToolCall(tc.FunctionName, tc.FunctionName, tc.FunctionArguments));
foreach (var tc in toolCalls)
{
assistantMessage.ToolCalls.Add(tc);
}
var toolCallResults = toolCallResultMessage.ToolCalls.Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName));
// return assistantMessage and tool call result messages
var messages = new List<ChatRequestMessage> { assistantMessage };
messages.AddRange(toolCallResults);
return messages;
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForOther(TextMessage message)
{
if (message.Role == Role.System)
{
return new[] { new ChatRequestSystemMessage(message.Content) };
}
else
{
return new[] { new ChatRequestUserMessage(message.Content) };
}
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForOther(ImageMessage message)
{
return new[] { new ChatRequestUserMessage([
new ChatMessageImageContentItem(new Uri(message.Url)),
])};
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForOther(MultiModalMessage message)
{
IEnumerable<ChatMessageContentItem> items = message.Content.Select<IMessage, ChatMessageContentItem>(ci => ci switch
{
TextMessage text => new ChatMessageTextContentItem(text.Content),
ImageMessage image => new ChatMessageImageContentItem(new Uri(image.Url)),
_ => throw new NotImplementedException(),
});
return new[] { new ChatRequestUserMessage(items) };
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForOther(ToolCallMessage msg)
{
throw new ArgumentException("ToolCallMessage is not supported when message.From is not the same with agent");
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForOther(ToolCallResultMessage message)
{
return message.ToolCalls.Select(tc => new ChatRequestToolMessage(tc.Result, tc.FunctionName));
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForOther(Message message)
{
if (message.Role == Role.System)
{
return new[] { new ChatRequestSystemMessage(message.Content) };
}
else if (message.Content is string content && content is { Length: > 0 })
{
if (message.FunctionName is not null)
{
return new[] { new ChatRequestToolMessage(content, message.FunctionName) };
}
return new[] { new ChatRequestUserMessage(message.Content) };
}
else if (message.FunctionName is string _)
{
return new[]
{
new ChatRequestUserMessage("// Message type is not supported"),
};
}
else
{
throw new InvalidOperationException("Invalid Message as message from other.");
}
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForOther(IMessage<ChatRequestMessage> message)
{
return new[] { message.Content };
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesForOther(AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage)
{
// convert as user message
var resultMessage = aggregateMessage.Message2;
return resultMessage.ToolCalls.Select(tc => new ChatRequestUserMessage(tc.Result));
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesWithEmptyFrom(TextMessage message)
{
return ProcessIncomingMessagesForOther(message);
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesWithEmptyFrom(ImageMessage message)
{
return ProcessIncomingMessagesForOther(message);
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesWithEmptyFrom(MultiModalMessage message)
{
return ProcessIncomingMessagesForOther(message);
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesWithEmptyFrom(ToolCallMessage message)
{
return ProcessIncomingMessagesForSelf(message);
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesWithEmptyFrom(ToolCallResultMessage message)
{
return ProcessIncomingMessagesForOther(message);
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesWithEmptyFrom(Message message)
{
return ProcessIncomingMessagesForOther(message);
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesWithEmptyFrom(IMessage<ChatRequestMessage> message)
{
return new[] { message.Content };
}
private IEnumerable<ChatRequestMessage> ProcessIncomingMessagesWithEmptyFrom(AggregateMessage<ToolCallMessage, ToolCallResultMessage> aggregateMessage)
{
return ProcessIncomingMessagesForOther(aggregateMessage);
}
}

View File

@@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// OpenAIConfig.cs
namespace AutoGen.OpenAI;
public class OpenAIConfig : ILLMConfig
{
public OpenAIConfig(string apiKey, string modelId)
{
this.ApiKey = apiKey;
this.ModelId = modelId;
}
public string ApiKey { get; }
public string ModelId { get; }
}

View File

@@ -0,0 +1,27 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<RootNamespace>AutoGen.SemanticKernel</RootNamespace>
</PropertyGroup>
<Import Project="$(RepoRoot)/dotnet/nuget/nuget-package.props" />
<PropertyGroup>
<!-- NuGet Package Settings -->
<Title>AutoGen.SemanticKernel</Title>
<Description>
This package contains the semantic kernel integration for AutoGen
</Description>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Azure.AI.OpenAI" Version="$(AzureOpenAIVersion)" />
<PackageReference Include="Microsoft.SemanticKernel" Version="$(SemanticKernelVersion)" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\AutoGen.Core\AutoGen.Core.csproj" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// KernelExtension.cs
using Microsoft.SemanticKernel;
namespace AutoGen.SemanticKernel.Extension;
public static class KernelExtension
{
public static SemanticKernelAgent ToSemanticKernelAgent(this Kernel kernel, string name, string systemMessage = "You are a helpful AI assistant", PromptExecutionSettings? settings = null)
{
return new SemanticKernelAgent(kernel, name, systemMessage, settings);
}
}

View File

@@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SemanticKernelAgentExtension.cs
namespace AutoGen.SemanticKernel.Extension;
public static class SemanticKernelAgentExtension
{
/// <summary>
/// Register an <see cref="SemanticKernelChatMessageContentConnector"/> to the <see cref="SemanticKernelAgent"/>
/// </summary>
/// <param name="connector">the connector to use. If null, a new instance of <see cref="SemanticKernelChatMessageContentConnector"/> will be created.</param>
public static MiddlewareStreamingAgent<SemanticKernelAgent> RegisterMessageConnector(
this SemanticKernelAgent agent, SemanticKernelChatMessageContentConnector? connector = null)
{
if (connector == null)
{
connector = new SemanticKernelChatMessageContentConnector();
}
return agent.RegisterStreamingMiddleware(connector);
}
/// <summary>
/// Register an <see cref="SemanticKernelChatMessageContentConnector"/> to the <see cref="MiddlewareAgent{T}"/> where T is <see cref="SemanticKernelAgent"/>
/// </summary>
/// <param name="connector">the connector to use. If null, a new instance of <see cref="SemanticKernelChatMessageContentConnector"/> will be created.</param>
public static MiddlewareStreamingAgent<SemanticKernelAgent> RegisterMessageConnector(
this MiddlewareStreamingAgent<SemanticKernelAgent> agent, SemanticKernelChatMessageContentConnector? connector = null)
{
if (connector == null)
{
connector = new SemanticKernelChatMessageContentConnector();
}
return agent.RegisterStreamingMiddleware(connector);
}
}

View File

@@ -0,0 +1,4 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GlobalUsing.cs
global using AutoGen.Core;

View File

@@ -0,0 +1,260 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SemanticKernelChatMessageContentConnector.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
namespace AutoGen.SemanticKernel;
/// <summary>
/// This middleware converts the incoming <see cref="IMessage"/> to <see cref="ChatMessageContent"/> before passing to agent.
/// And converts the reply message from <see cref="ChatMessageContent"/> to <see cref="IMessage"/> before returning to the caller.
///
/// <para>requirement for agent</para>
/// <para>- Input message type: <see cref="IMessage{T}"/> where T is <see cref="ChatMessageContent"/></para>
/// <para>- Reply message type: <see cref="IMessage{T}"/> where T is <see cref="ChatMessageContent"/></para>
/// <para>- (streaming) Reply message type: <see cref="IMessage{T}"/> where T is <see cref="StreamingChatMessageContent"/></para>
///
/// This middleware supports the following message types:
/// <para>- <see cref="TextMessage"/></para>
/// <para>- <see cref="ImageMessage"/></para>
/// <para>- <see cref="MultiModalMessage"/></para>
///
/// This middleware returns the following message types:
/// <para>- <see cref="TextMessage"/></para>
/// <para>- <see cref="ImageMessage"/></para>
/// <para>- <see cref="MultiModalMessage"/></para>
/// <para>- (streaming) <see cref="TextMessageUpdate"/></para>
/// </summary>
public class SemanticKernelChatMessageContentConnector : IMiddleware, IStreamingMiddleware
{
public string? Name => nameof(SemanticKernelChatMessageContentConnector);
public async Task<IMessage> InvokeAsync(MiddlewareContext context, IAgent agent, CancellationToken cancellationToken = default)
{
var messages = context.Messages;
var chatMessageContents = ProcessMessage(messages, agent)
.Select(m => new MessageEnvelope<ChatMessageContent>(m));
var reply = await agent.GenerateReplyAsync(chatMessageContents, context.Options, cancellationToken);
return PostProcessMessage(reply);
}
public Task<IAsyncEnumerable<IStreamingMessage>> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default)
{
return Task.FromResult(InvokeStreamingAsync(context, agent, cancellationToken));
}
private async IAsyncEnumerable<IStreamingMessage> InvokeStreamingAsync(
MiddlewareContext context,
IStreamingAgent agent,
[EnumeratorCancellation] CancellationToken cancellationToken)
{
var chatMessageContents = ProcessMessage(context.Messages, agent)
.Select(m => new MessageEnvelope<ChatMessageContent>(m));
await foreach (var reply in await agent.GenerateStreamingReplyAsync(chatMessageContents, context.Options, cancellationToken))
{
yield return PostProcessStreamingMessage(reply);
}
}
private IMessage PostProcessMessage(IMessage input)
{
return input switch
{
IMessage<ChatMessageContent> messageEnvelope => PostProcessMessage(messageEnvelope),
_ => input,
};
}
private IStreamingMessage PostProcessStreamingMessage(IStreamingMessage input)
{
return input switch
{
IStreamingMessage<StreamingChatMessageContent> streamingMessage => PostProcessMessage(streamingMessage),
IMessage msg => PostProcessMessage(msg),
_ => input,
};
}
private IMessage PostProcessMessage(IMessage<ChatMessageContent> messageEnvelope)
{
var chatMessageContent = messageEnvelope.Content;
var items = chatMessageContent.Items.Select<KernelContent, IMessage>(i => i switch
{
TextContent txt => new TextMessage(Role.Assistant, txt.Text!, messageEnvelope.From),
ImageContent img when img.Uri is Uri uri => new ImageMessage(Role.Assistant, uri.ToString(), from: messageEnvelope.From),
ImageContent img when img.Uri is null => throw new InvalidOperationException("ImageContent.Uri is null"),
_ => throw new InvalidOperationException("Unsupported content type"),
});
if (items.Count() == 1)
{
return items.First();
}
else
{
return new MultiModalMessage(Role.Assistant, items, from: messageEnvelope.From);
}
}
private IStreamingMessage PostProcessMessage(IStreamingMessage<StreamingChatMessageContent> streamingMessage)
{
var chatMessageContent = streamingMessage.Content;
if (chatMessageContent.ChoiceIndex > 0)
{
throw new InvalidOperationException("Only one choice is supported in streaming response");
}
return new TextMessageUpdate(Role.Assistant, chatMessageContent.Content, streamingMessage.From);
}
private IEnumerable<ChatMessageContent> ProcessMessage(IEnumerable<IMessage> messages, IAgent agent)
{
return messages.SelectMany(m =>
{
if (m is IMessage<ChatMessageContent> chatMessageContent)
{
return [chatMessageContent.Content];
}
if (m.From == agent.Name)
{
return ProcessMessageForSelf(m);
}
else
{
return ProcessMessageForOthers(m);
}
});
}
private IEnumerable<ChatMessageContent> ProcessMessageForSelf(IMessage message)
{
return message switch
{
TextMessage textMessage => ProcessMessageForSelf(textMessage),
MultiModalMessage multiModalMessage => ProcessMessageForSelf(multiModalMessage),
Message m => ProcessMessageForSelf(m),
_ => throw new System.NotImplementedException(),
};
}
private IEnumerable<ChatMessageContent> ProcessMessageForOthers(IMessage message)
{
return message switch
{
TextMessage textMessage => ProcessMessageForOthers(textMessage),
MultiModalMessage multiModalMessage => ProcessMessageForOthers(multiModalMessage),
ImageMessage imageMessage => ProcessMessageForOthers(imageMessage),
Message m => ProcessMessageForOthers(m),
_ => throw new InvalidOperationException("unsupported message type, only support TextMessage, ImageMessage, MultiModalMessage and Message."),
};
}
private IEnumerable<ChatMessageContent> ProcessMessageForSelf(TextMessage message)
{
if (message.Role == Role.System)
{
return [new ChatMessageContent(AuthorRole.System, message.Content)];
}
else
{
return [new ChatMessageContent(AuthorRole.Assistant, message.Content)];
}
}
private IEnumerable<ChatMessageContent> ProcessMessageForOthers(TextMessage message)
{
if (message.Role == Role.System)
{
return [new ChatMessageContent(AuthorRole.System, message.Content)];
}
else
{
return [new ChatMessageContent(AuthorRole.User, message.Content)];
}
}
private IEnumerable<ChatMessageContent> ProcessMessageForOthers(ImageMessage message)
{
var imageContent = new ImageContent(new Uri(message.Url));
var collectionItems = new ChatMessageContentItemCollection();
collectionItems.Add(imageContent);
return [new ChatMessageContent(AuthorRole.User, collectionItems)];
}
private IEnumerable<ChatMessageContent> ProcessMessageForSelf(MultiModalMessage message)
{
throw new System.InvalidOperationException("MultiModalMessage is not supported in the semantic kernel if it's from self.");
}
private IEnumerable<ChatMessageContent> ProcessMessageForOthers(MultiModalMessage message)
{
var collections = new ChatMessageContentItemCollection();
foreach (var item in message.Content)
{
if (item is TextMessage textContent)
{
collections.Add(new TextContent(textContent.Content));
}
else if (item is ImageMessage imageContent)
{
collections.Add(new ImageContent(new Uri(imageContent.Url)));
}
else
{
throw new InvalidOperationException($"Unsupported message type: {item.GetType().Name}");
}
}
return [new ChatMessageContent(AuthorRole.User, collections)];
}
private IEnumerable<ChatMessageContent> ProcessMessageForSelf(Message message)
{
if (message.Role == Role.System)
{
return [new ChatMessageContent(AuthorRole.System, message.Content)];
}
else if (message.Content is string && message.FunctionName is null && message.FunctionArguments is null)
{
return [new ChatMessageContent(AuthorRole.Assistant, message.Content)];
}
else if (message.Content is null && message.FunctionName is not null && message.FunctionArguments is not null)
{
throw new System.InvalidOperationException("Function call is not supported in the semantic kernel if it's from self.");
}
else
{
throw new System.InvalidOperationException("Unsupported message type");
}
}
private IEnumerable<ChatMessageContent> ProcessMessageForOthers(Message message)
{
if (message.Role == Role.System)
{
return [new ChatMessageContent(AuthorRole.System, message.Content)];
}
else if (message.Content is string && message.FunctionName is null && message.FunctionArguments is null)
{
return [new ChatMessageContent(AuthorRole.User, message.Content)];
}
else if (message.Content is null && message.FunctionName is not null && message.FunctionArguments is not null)
{
throw new System.InvalidOperationException("Function call is not supported in the semantic kernel if it's from others.");
}
else
{
throw new System.InvalidOperationException("Unsupported message type");
}
}
}

View File

@@ -0,0 +1,125 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SemanticKernelAgent.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;
namespace AutoGen.SemanticKernel;
/// <summary>
/// Semantic Kernel Agent
/// <listheader>Income message could be one of the following type:</listheader>
/// <list type="bullet">
/// <item><see cref="IMessage{T}"/> where T is <see cref="ChatMessageContent"/></item>
/// </list>
///
/// <listheader>Return message could be one of the following type:</listheader>
/// <list type="bullet">
/// <item><see cref="IMessage{T}"/> where T is <see cref="ChatMessageContent"/></item>
/// <item>(streaming) <see cref="IMessage{T}"/> where T is <see cref="StreamingChatMessageContent"/></item>
/// </list>
///
/// <para>To support more AutoGen built-in <see cref="IMessage"/>, register with <see cref="SemanticKernelChatMessageContentConnector"/>.</para>
/// </summary>
public class SemanticKernelAgent : IStreamingAgent
{
private readonly Kernel _kernel;
private readonly string _systemMessage;
private readonly PromptExecutionSettings? _settings;
public SemanticKernelAgent(
Kernel kernel,
string name,
string systemMessage = "You are a helpful AI assistant",
PromptExecutionSettings? settings = null)
{
_kernel = kernel;
this.Name = name;
_systemMessage = systemMessage;
_settings = settings;
}
public string Name { get; }
public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default)
{
var chatHistory = BuildChatHistory(messages);
var option = BuildOption(options);
var chatService = _kernel.GetRequiredService<IChatCompletionService>();
var reply = await chatService.GetChatMessageContentsAsync(chatHistory, option, _kernel, cancellationToken);
if (reply.Count > 1)
{
throw new InvalidOperationException("ResultsPerPrompt greater than 1 is not supported in this semantic kernel agent");
}
return new MessageEnvelope<ChatMessageContent>(reply.First(), from: this.Name);
}
public async Task<IAsyncEnumerable<IStreamingMessage>> GenerateStreamingReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? options = null,
CancellationToken cancellationToken = default)
{
var chatHistory = BuildChatHistory(messages);
var option = BuildOption(options);
var chatService = _kernel.GetRequiredService<IChatCompletionService>();
var response = chatService.GetStreamingChatMessageContentsAsync(chatHistory, option, _kernel, cancellationToken);
return ProcessMessage(response);
}
private ChatHistory BuildChatHistory(IEnumerable<IMessage> messages)
{
var chatMessageContents = ProcessMessage(messages);
// if there's no system message in chatMessageContents, add one to the beginning
if (!chatMessageContents.Any(c => c.Role == AuthorRole.System))
{
chatMessageContents = new[] { new ChatMessageContent(AuthorRole.System, _systemMessage) }.Concat(chatMessageContents);
}
return new ChatHistory(chatMessageContents);
}
private PromptExecutionSettings BuildOption(GenerateReplyOptions? options)
{
return _settings ?? new OpenAIPromptExecutionSettings
{
Temperature = options?.Temperature ?? 0.7f,
MaxTokens = options?.MaxToken ?? 1024,
StopSequences = options?.StopSequence,
ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions,
ResultsPerPrompt = 1,
};
}
private async IAsyncEnumerable<IMessage> ProcessMessage(IAsyncEnumerable<StreamingChatMessageContent> response)
{
await foreach (var content in response)
{
if (content.ChoiceIndex > 0)
{
throw new InvalidOperationException("Only one choice is supported in streaming response");
}
yield return new MessageEnvelope<StreamingChatMessageContent>(content, from: this.Name);
}
}
private IEnumerable<ChatMessageContent> ProcessMessage(IEnumerable<IMessage> messages)
{
return messages.Select(m => m switch
{
IMessage<ChatMessageContent> cmc => cmc.Content,
_ => throw new ArgumentException("Invalid message type")
});
}
}

View File

@@ -0,0 +1,60 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<IncludeBuildOutput>false</IncludeBuildOutput>
<!-- Do not include the generator as a lib dependency -->
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
<!-- Do not include the generator as a lib dependency -->
<UserSecretsId>35954224-b94e-4024-b0ef-7ba7cf80c0d8</UserSecretsId>
<GetTargetPathDependsOn>$(GetTargetPathDependsOn);GetDependencyTargetPaths</GetTargetPathDependsOn>
<LaunchDebugger>false</LaunchDebugger>
<NoWarn>$(NoWarn);NU5128</NoWarn>
<DefineConstants Condition="'$(LaunchDebugger)' == 'true'">$(DefineConstants);LAUNCH_DEBUGGER</DefineConstants>
</PropertyGroup>
<Import Project="$(RepoRoot)/dotnet/nuget/nuget-package.props" />
<PropertyGroup>
<Title>AutoGen.SourceGenerator</Title>
<Description>Source generator for AutoGen. This package provides type-safe function call to AutoGen agents.</Description>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="System.CodeDom" Version="$(SystemCodeDomVersion)" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="$(MicrosoftCodeAnalysisVersion)" PrivateAssets="all" GeneratePathProperty="True" />
<PackageReference Include="Newtonsoft.Json" PrivateAssets="all" Version="13.0.1" GeneratePathProperty="true" />
</ItemGroup>
<ItemGroup>
<None Include="$(OutputPath)\$(AssemblyName).dll" Pack="true" PackagePath="analyzers/dotnet/cs" Visible="false" />
<None Include="$(PkgNewtonsoft_Json)\lib\netstandard2.0\*.dll" Pack="true" PackagePath="analyzers/dotnet/cs" Visible="false" />
</ItemGroup>
<Target Name="GetDependencyTargetPaths">
<ItemGroup>
<TargetPathWithTargetPlatformMoniker Include="$(PkgNewtonsoft_Json)\lib\netstandard2.0\*.dll" IncludeRuntimeDependency="false" />
</ItemGroup>
</Target>
<ItemGroup>
<None Update="Template\FunctionCallTemplate.tt">
<Generator>TextTemplatingFilePreprocessor</Generator>
<LastGenOutput>FunctionCallTemplate.cs</LastGenOutput>
</None>
</ItemGroup>
<ItemGroup>
<Service Include="{508349b6-6b84-4df5-91f0-309beebad82d}" />
</ItemGroup>
<ItemGroup>
<Compile Update="Template\FunctionCallTemplate.cs">
<DesignTime>True</DesignTime>
<AutoGen>True</AutoGen>
<DependentUpon>FunctionCallTemplate.tt</DependentUpon>
</Compile>
</ItemGroup>
</Project>

View File

@@ -0,0 +1,295 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// DocumentCommentExtension.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Xml.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
// copyright: https://github.com/DotNetAnalyzers/StyleCopAnalyzers/blob/master/StyleCop.Analyzers/StyleCop.Analyzers/Helpers/DocumentationCommentExtensions.cs#L17
namespace AutoGen.SourceGenerator
{
internal static class DocumentCommentExtension
{
public static bool IsMissingOrDefault(this SyntaxToken token)
{
return token.IsKind(SyntaxKind.None)
|| token.IsMissing;
}
public static string? GetParameterDescriptionFromDocumentationCommentTriviaSyntax(this DocumentationCommentTriviaSyntax documentationCommentTrivia, string parameterName)
{
var parameterElements = documentationCommentTrivia.Content.GetXmlElements("param");
var parameter = parameterElements.FirstOrDefault(element =>
{
var xml = XElement.Parse(element.ToString());
var nameAttribute = xml.Attribute("name");
return nameAttribute != null && nameAttribute.Value == parameterName;
});
if (parameter is not null)
{
var xml = XElement.Parse(parameter.ToString());
return xml.Nodes().OfType<XText>().FirstOrDefault()?.Value;
}
return null;
}
public static string? GetNamespaceNameFromClassDeclarationSyntax(this ClassDeclarationSyntax classDeclaration)
{
return classDeclaration.Parent is NamespaceDeclarationSyntax namespaceDeclarationSyntax ? namespaceDeclarationSyntax.Name.ToString()
: classDeclaration.Parent is FileScopedNamespaceDeclarationSyntax fileScopedNamespaceDeclarationSyntax ? fileScopedNamespaceDeclarationSyntax.Name.ToString()
: null;
}
public static DocumentationCommentTriviaSyntax? GetDocumentationCommentTriviaSyntax(this SyntaxNode node)
{
if (node == null)
{
return null;
}
foreach (var leadingTrivia in node.GetLeadingTrivia())
{
if (leadingTrivia.GetStructure() is DocumentationCommentTriviaSyntax structure)
{
return structure;
}
}
return null;
}
public static XmlNodeSyntax GetFirstXmlElement(this SyntaxList<XmlNodeSyntax> content, string elementName)
{
return content.GetXmlElements(elementName).FirstOrDefault();
}
public static IEnumerable<XmlNodeSyntax> GetXmlElements(this SyntaxList<XmlNodeSyntax> content, string elementName)
{
foreach (XmlNodeSyntax syntax in content)
{
if (syntax is XmlEmptyElementSyntax emptyElement)
{
if (string.Equals(elementName, emptyElement.Name.ToString(), StringComparison.Ordinal))
{
yield return emptyElement;
}
continue;
}
if (syntax is XmlElementSyntax elementSyntax)
{
if (string.Equals(elementName, elementSyntax.StartTag?.Name?.ToString(), StringComparison.Ordinal))
{
yield return elementSyntax;
}
continue;
}
}
}
public static T ReplaceExteriorTrivia<T>(this T node, SyntaxTrivia trivia)
where T : XmlNodeSyntax
{
// Make sure to include a space after the '///' characters.
SyntaxTrivia triviaWithSpace = SyntaxFactory.DocumentationCommentExterior(trivia.ToString() + " ");
return node.ReplaceTrivia(
node.DescendantTrivia(descendIntoTrivia: true).Where(i => i.IsKind(SyntaxKind.DocumentationCommentExteriorTrivia)),
(originalTrivia, rewrittenTrivia) => SelectExteriorTrivia(rewrittenTrivia, trivia, triviaWithSpace));
}
public static SyntaxList<XmlNodeSyntax> WithoutFirstAndLastNewlines(this SyntaxList<XmlNodeSyntax> summaryContent)
{
if (summaryContent.Count == 0)
{
return summaryContent;
}
if (!(summaryContent[0] is XmlTextSyntax firstSyntax))
{
return summaryContent;
}
if (!(summaryContent[summaryContent.Count - 1] is XmlTextSyntax lastSyntax))
{
return summaryContent;
}
SyntaxTokenList firstSyntaxTokens = firstSyntax.TextTokens;
int removeFromStart;
if (IsXmlNewLine(firstSyntaxTokens[0]))
{
removeFromStart = 1;
}
else
{
if (!IsXmlWhitespace(firstSyntaxTokens[0]))
{
return summaryContent;
}
if (!IsXmlNewLine(firstSyntaxTokens[1]))
{
return summaryContent;
}
removeFromStart = 2;
}
SyntaxTokenList lastSyntaxTokens = lastSyntax.TextTokens;
int removeFromEnd;
if (IsXmlNewLine(lastSyntaxTokens[lastSyntaxTokens.Count - 1]))
{
removeFromEnd = 1;
}
else
{
if (!IsXmlWhitespace(lastSyntaxTokens[lastSyntaxTokens.Count - 1]))
{
return summaryContent;
}
if (!IsXmlNewLine(lastSyntaxTokens[lastSyntaxTokens.Count - 2]))
{
return summaryContent;
}
removeFromEnd = 2;
}
for (int i = 0; i < removeFromStart; i++)
{
firstSyntaxTokens = firstSyntaxTokens.RemoveAt(0);
}
if (firstSyntax == lastSyntax)
{
lastSyntaxTokens = firstSyntaxTokens;
}
for (int i = 0; i < removeFromEnd; i++)
{
if (!lastSyntaxTokens.Any())
{
break;
}
lastSyntaxTokens = lastSyntaxTokens.RemoveAt(lastSyntaxTokens.Count - 1);
}
summaryContent = summaryContent.RemoveAt(summaryContent.Count - 1);
if (lastSyntaxTokens.Count != 0)
{
summaryContent = summaryContent.Add(lastSyntax.WithTextTokens(lastSyntaxTokens));
}
if (firstSyntax != lastSyntax)
{
summaryContent = summaryContent.RemoveAt(0);
if (firstSyntaxTokens.Count != 0)
{
summaryContent = summaryContent.Insert(0, firstSyntax.WithTextTokens(firstSyntaxTokens));
}
}
if (summaryContent.Count > 0)
{
// Make sure to remove the leading trivia
summaryContent = summaryContent.Replace(summaryContent[0], summaryContent[0].WithLeadingTrivia());
// Remove leading spaces (between the <para> start tag and the start of the paragraph content)
if (summaryContent[0] is XmlTextSyntax firstTextSyntax && firstTextSyntax.TextTokens.Count > 0)
{
SyntaxToken firstTextToken = firstTextSyntax.TextTokens[0];
string firstTokenText = firstTextToken.Text;
string trimmed = firstTokenText.TrimStart();
if (trimmed != firstTokenText)
{
SyntaxToken newFirstToken = SyntaxFactory.Token(
firstTextToken.LeadingTrivia,
firstTextToken.Kind(),
trimmed,
firstTextToken.ValueText.TrimStart(),
firstTextToken.TrailingTrivia);
summaryContent = summaryContent.Replace(firstTextSyntax, firstTextSyntax.ReplaceToken(firstTextToken, newFirstToken));
}
}
}
return summaryContent;
}
public static bool IsXmlNewLine(this SyntaxToken node)
{
return node.IsKind(SyntaxKind.XmlTextLiteralNewLineToken);
}
public static bool IsXmlWhitespace(this SyntaxToken node)
{
return node.IsKind(SyntaxKind.XmlTextLiteralToken)
&& string.IsNullOrWhiteSpace(node.Text);
}
/// <summary>
/// Adjust the leading and trailing trivia associated with <see cref="SyntaxKind.XmlTextLiteralNewLineToken"/>
/// tokens to ensure the formatter properly indents the exterior trivia.
/// </summary>
/// <typeparam name="T">The type of syntax node.</typeparam>
/// <param name="node">The syntax node to adjust tokens.</param>
/// <returns>A <see cref="SyntaxNode"/> equivalent to the input <paramref name="node"/>, adjusted by moving any
/// trailing trivia from <see cref="SyntaxKind.XmlTextLiteralNewLineToken"/> tokens to be leading trivia of the
/// following token.</returns>
public static T AdjustDocumentationCommentNewLineTrivia<T>(this T node)
where T : SyntaxNode
{
var tokensForAdjustment =
from token in node.DescendantTokens()
where token.IsKind(SyntaxKind.XmlTextLiteralNewLineToken)
where token.HasTrailingTrivia
let next = token.GetNextToken(includeZeroWidth: true, includeSkipped: true, includeDirectives: true, includeDocumentationComments: true)
where !next.IsMissingOrDefault()
select new KeyValuePair<SyntaxToken, SyntaxToken>(token, next);
Dictionary<SyntaxToken, SyntaxToken> replacements = new Dictionary<SyntaxToken, SyntaxToken>();
foreach (var pair in tokensForAdjustment)
{
replacements[pair.Key] = pair.Key.WithTrailingTrivia();
replacements[pair.Value] = pair.Value.WithLeadingTrivia(pair.Value.LeadingTrivia.InsertRange(0, pair.Key.TrailingTrivia));
}
return node.ReplaceTokens(replacements.Keys, (originalToken, rewrittenToken) => replacements[originalToken]);
}
public static XmlNameSyntax? GetName(this XmlNodeSyntax element)
{
return (element as XmlElementSyntax)?.StartTag?.Name
?? (element as XmlEmptyElementSyntax)?.Name;
}
private static SyntaxTrivia SelectExteriorTrivia(SyntaxTrivia rewrittenTrivia, SyntaxTrivia trivia, SyntaxTrivia triviaWithSpace)
{
// if the trivia had a trailing space, make sure to preserve it
if (rewrittenTrivia.ToString().EndsWith(" "))
{
return triviaWithSpace;
}
// otherwise the space is part of the leading trivia of the following token, so don't add an extra one to
// the exterior trivia
return trivia;
}
}
}

View File

@@ -0,0 +1,248 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// FunctionCallGenerator.cs
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Xml.Linq;
using AutoGen.SourceGenerator.Template;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;
using Newtonsoft.Json;
namespace AutoGen.SourceGenerator
{
[Generator]
public partial class FunctionCallGenerator : IIncrementalGenerator
{
private const string FUNCTION_CALL_ATTRIBUTION = "AutoGen.Core.FunctionAttribute";
public void Initialize(IncrementalGeneratorInitializationContext context)
{
#if LAUNCH_DEBUGGER
if (!System.Diagnostics.Debugger.IsAttached)
{
System.Diagnostics.Debugger.Launch();
}
#endif
var optionProvider = context.AnalyzerConfigOptionsProvider.Select((provider, ct) =>
{
var generateFunctionDefinitionContract = provider.GlobalOptions.TryGetValue("build_property.EnableContract", out var value) && value?.ToLowerInvariant() == "true";
return generateFunctionDefinitionContract;
});
// step 1
// filter syntax tree and search syntax node that satisfied the following conditions
// - is partial class
var partialClassSyntaxProvider = context.SyntaxProvider.CreateSyntaxProvider<PartialClassOutput?>(
(node, ct) =>
{
return node is ClassDeclarationSyntax classDeclarationSyntax && classDeclarationSyntax.Modifiers.Any(SyntaxKind.PartialKeyword);
},
(ctx, ct) =>
{
// first check if any method of the class has FunctionAttribution attribute
// if not, then return null
var filePath = ctx.Node.SyntaxTree.FilePath;
var fileName = Path.GetFileNameWithoutExtension(filePath);
var classDeclarationSyntax = ctx.Node as ClassDeclarationSyntax;
var nameSpace = classDeclarationSyntax?.Parent as NamespaceDeclarationSyntax;
var fullClassName = $"{nameSpace?.Name}.{classDeclarationSyntax!.Identifier}";
if (classDeclarationSyntax == null)
{
return null;
}
if (!classDeclarationSyntax.Members.Any(member => member.AttributeLists.Any(attributeList => attributeList.Attributes.Any(attribute =>
{
return ctx.SemanticModel.GetSymbolInfo(attribute).Symbol is IMethodSymbol methodSymbol && methodSymbol.ContainingType.ToDisplayString() == FUNCTION_CALL_ATTRIBUTION;
}))))
{
return null;
}
// collect methods that has FunctionAttribution attribute
var methodDeclarationSyntaxes = classDeclarationSyntax.Members.Where(member => member.AttributeLists.Any(attributeList => attributeList.Attributes.Any(attribute =>
{
return ctx.SemanticModel.GetSymbolInfo(attribute).Symbol is IMethodSymbol methodSymbol && methodSymbol.ContainingType.ToDisplayString() == FUNCTION_CALL_ATTRIBUTION;
})))
.Select(member => member as MethodDeclarationSyntax)
.Where(method => method != null);
var className = classDeclarationSyntax.Identifier.ToString();
var namespaceName = classDeclarationSyntax.GetNamespaceNameFromClassDeclarationSyntax();
var functionContracts = methodDeclarationSyntaxes.Select(method => CreateFunctionContract(method!, className, namespaceName));
return new PartialClassOutput(fullClassName, classDeclarationSyntax, functionContracts);
})
.Where(node => node != null)
.Collect();
var aggregateProvider = optionProvider.Combine(partialClassSyntaxProvider);
// step 2
context.RegisterSourceOutput(aggregateProvider,
(ctx, source) =>
{
var groups = source.Right.GroupBy(item => item!.FullClassName);
foreach (var group in groups)
{
var functionContracts = group.SelectMany(item => item!.FunctionContracts).ToArray();
var className = group.First()!.ClassDeclarationSyntax.Identifier.ToString();
var namespaceName = group.First()!.ClassDeclarationSyntax.GetNamespaceNameFromClassDeclarationSyntax() ?? string.Empty;
var functionTT = new FunctionCallTemplate
{
NameSpace = namespaceName,
ClassName = className,
FunctionContracts = functionContracts.ToArray(),
};
var functionSource = functionTT.TransformText();
var fileName = $"{className}.generated.cs";
ctx.AddSource(fileName, SourceText.From(functionSource, System.Text.Encoding.UTF8));
File.WriteAllText(Path.Combine(Path.GetTempPath(), fileName), functionSource);
}
if (source.Left)
{
var overallFunctionDefinition = source.Right.SelectMany(x => x!.FunctionContracts.Select(y => new { fullClassName = x.FullClassName, y = y }));
var overallFunctionDefinitionObject = overallFunctionDefinition.Select(
x => new
{
fullClassName = x.fullClassName,
functionDefinition = new
{
x.y.Name,
x.y.Description,
x.y.ReturnType,
Parameters = x.y.Parameters.Select(y => new
{
y.Name,
y.Description,
y.JsonType,
y.JsonItemType,
y.Type,
y.IsOptional,
y.DefaultValue,
}),
},
});
var json = JsonConvert.SerializeObject(overallFunctionDefinitionObject, formatting: Formatting.Indented);
// wrap json inside csharp block, as SG doesn't support generating non-source file
json = $@"/* <auto-generated> wrap json inside csharp block, as SG doesn't support generating non-source file
{json}
</auto-generated>*/";
ctx.AddSource("FunctionDefinition.json", SourceText.From(json, System.Text.Encoding.UTF8));
}
});
}
private class PartialClassOutput
{
public PartialClassOutput(string fullClassName, ClassDeclarationSyntax classDeclarationSyntax, IEnumerable<FunctionContract> functionContracts)
{
FullClassName = fullClassName;
ClassDeclarationSyntax = classDeclarationSyntax;
FunctionContracts = functionContracts;
}
public string FullClassName { get; }
public ClassDeclarationSyntax ClassDeclarationSyntax { get; }
public IEnumerable<FunctionContract> FunctionContracts { get; }
}
private FunctionContract CreateFunctionContract(MethodDeclarationSyntax method, string? className, string? namespaceName)
{
// get function_call attribute
var functionCallAttribute = method.AttributeLists.SelectMany(attributeList => attributeList.Attributes)
.FirstOrDefault(attribute => attribute.Name.ToString() == FUNCTION_CALL_ATTRIBUTION);
// get document string if exist
var documentationCommentTrivia = method.GetDocumentationCommentTriviaSyntax();
var functionName = method.Identifier.ToString();
var functionDescription = functionCallAttribute?.ArgumentList?.Arguments.FirstOrDefault(argument => argument.NameEquals?.Name.ToString() == "Description")?.Expression.ToString() ?? string.Empty;
if (string.IsNullOrEmpty(functionDescription))
{
// if functionDescription is empty, then try to get it from documentationCommentTrivia
// firstly, try getting from <summary> tag
var summary = documentationCommentTrivia?.Content.GetFirstXmlElement("summary");
if (summary is not null && XElement.Parse(summary.ToString()) is XElement element)
{
functionDescription = element.Nodes().OfType<XText>().FirstOrDefault()?.Value;
// remove [space...][//|///][space...] from functionDescription
// replace [^\S\r\n]+[\/]+\s* with empty string
functionDescription = System.Text.RegularExpressions.Regex.Replace(functionDescription, @"[^\S\r\n]+\/[\/]+\s*", string.Empty);
}
else
{
// if <summary> tag is not exist, then simply use the entire leading trivia as functionDescription
functionDescription = method.GetLeadingTrivia().ToString();
// remove [space...][//|///][space...] from functionDescription
// replace [^\S\r\n]+[\/]+\s* with empty string
functionDescription = System.Text.RegularExpressions.Regex.Replace(functionDescription, @"[^\S\r\n]+\/[\/]+\s*", string.Empty);
}
}
// get parameters
var parameters = method.ParameterList.Parameters.Select(parameter =>
{
var description = $"{parameter.Identifier}. type is {parameter.Type}";
// try to get parameter description from documentationCommentTrivia
var parameterDocumentationComment = documentationCommentTrivia?.GetParameterDescriptionFromDocumentationCommentTriviaSyntax(parameter.Identifier.ToString());
if (parameterDocumentationComment is not null)
{
description = parameterDocumentationComment.ToString();
// remove [space...][//|///][space...] from functionDescription
// replace [^\S\r\n]+[\/]+\s* with empty string
description = System.Text.RegularExpressions.Regex.Replace(description, @"[^\S\r\n]+\/[\/]+\s*", string.Empty);
}
var jsonItemType = parameter.Type!.ToString().EndsWith("[]") ? parameter.Type!.ToString().Substring(0, parameter.Type!.ToString().Length - 2) : null;
return new ParameterContract
{
Name = parameter.Identifier.ToString(),
JsonType = parameter.Type!.ToString() switch
{
"string" => "string",
"string[]" => "array",
"System.Int32" or "int" => "integer",
"System.Int64" or "long" => "integer",
"System.Single" or "float" => "number",
"System.Double" or "double" => "number",
"System.Boolean" or "bool" => "boolean",
"System.DateTime" => "string",
"System.Guid" => "string",
"System.Object" => "object",
_ => "object",
},
JsonItemType = jsonItemType,
Type = parameter.Type!.ToString(),
Description = description,
IsOptional = parameter.Default != null,
// if Default is null or "null", then DefaultValue is null
DefaultValue = parameter.Default?.ToString() == "null" ? null : parameter.Default?.Value.ToString(),
};
});
return new FunctionContract
{
ClassName = className,
Namespace = namespaceName,
Name = functionName,
Description = functionDescription?.Trim() ?? functionName,
Parameters = parameters.ToArray(),
ReturnType = method.ReturnType.ToString(),
};
}
}
}

View File

@@ -0,0 +1,40 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// FunctionContract.cs
namespace AutoGen.SourceGenerator
{
internal class FunctionContract
{
public string? Namespace { get; set; }
public string? ClassName { get; set; }
public string? Name { get; set; }
public string? Description { get; set; }
public string? ReturnDescription { get; set; }
public ParameterContract[]? Parameters { get; set; }
public string? ReturnType { get; set; }
}
internal class ParameterContract
{
public string? Name { get; set; }
public string? Description { get; set; }
public string? JsonType { get; set; }
public string? JsonItemType { get; set; }
public string? Type { get; set; }
public bool IsOptional { get; set; }
public string? DefaultValue { get; set; }
}
}

View File

@@ -0,0 +1,32 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// FunctionExtension.cs
using AutoGen.SourceGenerator;
internal static class FunctionExtension
{
public static string GetFunctionName(this FunctionContract function)
{
return function.Name ?? string.Empty;
}
public static string GetFunctionSchemaClassName(this FunctionContract function)
{
return $"{function.GetFunctionName()}Schema";
}
public static string GetFunctionDefinitionName(this FunctionContract function)
{
return $"{function.GetFunctionName()}Function";
}
public static string GetFunctionWrapperName(this FunctionContract function)
{
return $"{function.GetFunctionName()}Wrapper";
}
public static string GetFunctionContractName(this FunctionContract function)
{
return $"{function.GetFunctionName()}FunctionContract";
}
}

View File

@@ -0,0 +1,113 @@
### AutoGen.SourceGenerator
This package carries a source generator that adds support for type-safe function definition generation. Simply mark a method with `Function` attribute, and the source generator will generate a function definition and a function call wrapper for you.
### Get start
First, add the following to your project file and set `GenerateDocumentationFile` property to true
```xml
<PropertyGroup>
<!-- This enables structural xml document support -->
<GenerateDocumentationFile>true</GenerateDocumentationFile>
</PropertyGroup>
```
```xml
<ItemGroup>
<PackageReference Include="AutoGen.SourceGenerator" />
</ItemGroup>
```
> Nightly Build feed: https://devdiv.pkgs.visualstudio.com/DevDiv/_packaging/AutoGen/nuget/v3/index.json
Then, for the methods you want to generate function definition and function call wrapper, mark them with `Function` attribute:
> Note: For the best of performance, try using primitive types for the parameters and return type.
```csharp
// file: MyFunctions.cs
using AutoGen;
// a partial class is required
// and the class must be public
public partial class MyFunctions
{
/// <summary>
/// Add two numbers.
/// </summary>
/// <param name="a">The first number.</param>
/// <param name="b">The second number.</param>
[Function]
public Task<string> AddAsync(int a, int b)
{
return Task.FromResult($"{a} + {b} = {a + b}");
}
}
```
The source generator will generate the following code based on the method signature and documentation. It helps you save the effort of writing function definition and keep it up to date with the actual method signature.
```csharp
// file: MyFunctions.generated.cs
public partial class MyFunctions
{
private class AddAsyncSchema
{
public int a {get; set;}
public int b {get; set;}
}
public Task<string> AddAsyncWrapper(string arguments)
{
var schema = JsonSerializer.Deserialize<AddAsyncSchema>(
arguments,
new JsonSerializerOptions
{
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
});
return AddAsync(schema.a, schema.b);
}
public FunctionDefinition AddAsyncFunction
{
get => new FunctionDefinition
{
Name = @"AddAsync",
Description = """
Add two numbers.
""",
Parameters = BinaryData.FromObjectAsJson(new
{
Type = "object",
Properties = new
{
a = new
{
Type = @"number",
Description = @"The first number.",
},
b = new
{
Type = @"number",
Description = @"The second number.",
},
},
Required = new []
{
"a",
"b",
},
},
new JsonSerializerOptions
{
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
})
};
}
}
```
For more examples, please check out the following project
- [AutoGen.BasicSamples](../sample/AutoGen.BasicSamples/)
- [AutoGen.SourceGenerator.Tests](../../test/AutoGen.SourceGenerator.Tests/)

View File

@@ -0,0 +1,447 @@
// ------------------------------------------------------------------------------
// <auto-generated>
// This code was generated by a tool.
// Runtime Version: 17.0.0.0
//
// Changes to this file may cause incorrect behavior and will be lost if
// the code is regenerated.
// </auto-generated>
// ------------------------------------------------------------------------------
namespace AutoGen.SourceGenerator.Template
{
using System.Linq;
using System.Collections.Generic;
using Microsoft.CodeAnalysis;
using System;
/// <summary>
/// Class to produce the template output
/// </summary>
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "17.0.0.0")]
internal partial class FunctionCallTemplate : FunctionCallTemplateBase
{
/// <summary>
/// Create the template output
/// </summary>
public virtual string TransformText()
{
this.Write("");
this.Write(@"//----------------------
// <auto-generated>
// This code was generated by a tool.
// </auto-generated>
//----------------------
using Azure.AI.OpenAI;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Tasks;
using System;
using AutoGen.Core;
using AutoGen.OpenAI.Extension;
");
if (!String.IsNullOrEmpty(NameSpace)) {
this.Write("namespace ");
this.Write(this.ToStringHelper.ToStringWithCulture(NameSpace));
this.Write("\r\n{\r\n");
}
this.Write(" public partial class ");
this.Write(this.ToStringHelper.ToStringWithCulture(ClassName));
this.Write("\r\n {\r\n");
foreach (var functionContract in FunctionContracts) {
this.Write("\r\n private class ");
this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.GetFunctionSchemaClassName()));
this.Write("\r\n {\r\n");
foreach (var parameter in functionContract.Parameters) {
if (parameter.IsOptional) {
this.Write(" [JsonPropertyName(@\"");
this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Name));
this.Write("\")]\r\n\t\t\tpublic ");
this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Type));
this.Write(" ");
this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Name));
this.Write(" {get; set;} = ");
this.Write(this.ToStringHelper.ToStringWithCulture(parameter.DefaultValue));
this.Write(";\r\n");
} else {
this.Write(" [JsonPropertyName(@\"");
this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Name));
this.Write("\")]\r\n\t\t\tpublic ");
this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Type));
this.Write(" ");
this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Name));
this.Write(" {get; set;}\r\n");
}
}
this.Write(" }\r\n\r\n public ");
this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.ReturnType));
this.Write(" ");
this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.GetFunctionWrapperName()));
this.Write("(string arguments)\r\n {\r\n var schema = JsonSerializer.Deserializ" +
"e<");
this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.GetFunctionSchemaClassName()));
this.Write(">(\r\n arguments, \r\n new JsonSerializerOptions\r\n " +
" {\r\n PropertyNamingPolicy = JsonNamingPolicy.CamelC" +
"ase,\r\n });\r\n");
var argumentLists = string.Join(", ", functionContract.Parameters.Select(p => $"schema.{p.Name}"));
this.Write("\r\n return ");
this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.Name));
this.Write("(");
this.Write(this.ToStringHelper.ToStringWithCulture(argumentLists));
this.Write(");\r\n }\r\n\r\n public FunctionContract ");
this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.GetFunctionContractName()));
this.Write("\r\n {\r\n get => new FunctionContract\r\n {\r\n");
if (functionContract.Namespace != null) {
this.Write(" Namespace = @\"");
this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.Namespace));
this.Write("\",\r\n");
}
if (functionContract.ClassName != null) {
this.Write(" ClassName = @\"");
this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.ClassName));
this.Write("\",\r\n");
}
if (functionContract.Name != null) {
this.Write(" Name = @\"");
this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.Name));
this.Write("\",\r\n");
}
if (functionContract.Description != null) {
this.Write(" Description = @\"");
this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.Description));
this.Write("\",\r\n");
}
if (functionContract.ReturnType != null) {
this.Write(" ReturnType = typeof(");
this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.ReturnType));
this.Write("),\r\n");
}
if (functionContract.ReturnDescription != null) {
this.Write(" ReturnDescription = @\"");
this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.ReturnDescription));
this.Write("\",\r\n");
}
if (functionContract.Parameters != null) {
this.Write(" Parameters = new []\r\n {\r\n");
foreach (var parameter in functionContract.Parameters) {
this.Write(" new FunctionParameterContract\r\n {\r\n");
if (parameter.Name != null) {
this.Write(" Name = @\"");
this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Name));
this.Write("\",\r\n");
}
if (parameter.Description != null) {
this.Write(" Description = @\"");
this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Description));
this.Write("\",\r\n");
}
if (parameter.Type != null) {
this.Write(" ParameterType = typeof(");
this.Write(this.ToStringHelper.ToStringWithCulture(parameter.Type));
this.Write("),\r\n");
}
this.Write(" IsRequired = ");
this.Write(this.ToStringHelper.ToStringWithCulture(parameter.IsOptional ? "false" : "true"));
this.Write(",\r\n");
if (parameter.DefaultValue != null) {
this.Write(" DefaultValue = ");
this.Write(this.ToStringHelper.ToStringWithCulture(parameter.DefaultValue));
this.Write(",\r\n");
}
this.Write(" },\r\n");
}
this.Write(" },\r\n");
}
this.Write(" };\r\n }\r\n\r\n public Azure.AI.OpenAI.FunctionDefinition ");
this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.GetFunctionDefinitionName()));
this.Write("\r\n {\r\n get => this.");
this.Write(this.ToStringHelper.ToStringWithCulture(functionContract.GetFunctionContractName()));
this.Write(".ToOpenAIFunctionDefinition();\r\n }\r\n");
}
this.Write(" }\r\n");
if (!String.IsNullOrEmpty(NameSpace)) {
this.Write("}\r\n");
}
this.Write("\r\n");
return this.GenerationEnvironment.ToString();
}
public string NameSpace {get; set;}
public string ClassName {get; set;}
public IEnumerable<FunctionContract> FunctionContracts {get; set;}
public bool IsStatic {get; set;} = false;
}
#region Base class
/// <summary>
/// Base class for this transformation
/// </summary>
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "17.0.0.0")]
internal class FunctionCallTemplateBase
{
#region Fields
private global::System.Text.StringBuilder generationEnvironmentField;
private global::System.CodeDom.Compiler.CompilerErrorCollection errorsField;
private global::System.Collections.Generic.List<int> indentLengthsField;
private string currentIndentField = "";
private bool endsWithNewline;
private global::System.Collections.Generic.IDictionary<string, object> sessionField;
#endregion
#region Properties
/// <summary>
/// The string builder that generation-time code is using to assemble generated output
/// </summary>
public System.Text.StringBuilder GenerationEnvironment
{
get
{
if ((this.generationEnvironmentField == null))
{
this.generationEnvironmentField = new global::System.Text.StringBuilder();
}
return this.generationEnvironmentField;
}
set
{
this.generationEnvironmentField = value;
}
}
/// <summary>
/// The error collection for the generation process
/// </summary>
public System.CodeDom.Compiler.CompilerErrorCollection Errors
{
get
{
if ((this.errorsField == null))
{
this.errorsField = new global::System.CodeDom.Compiler.CompilerErrorCollection();
}
return this.errorsField;
}
}
/// <summary>
/// A list of the lengths of each indent that was added with PushIndent
/// </summary>
private System.Collections.Generic.List<int> indentLengths
{
get
{
if ((this.indentLengthsField == null))
{
this.indentLengthsField = new global::System.Collections.Generic.List<int>();
}
return this.indentLengthsField;
}
}
/// <summary>
/// Gets the current indent we use when adding lines to the output
/// </summary>
public string CurrentIndent
{
get
{
return this.currentIndentField;
}
}
/// <summary>
/// Current transformation session
/// </summary>
public virtual global::System.Collections.Generic.IDictionary<string, object> Session
{
get
{
return this.sessionField;
}
set
{
this.sessionField = value;
}
}
#endregion
#region Transform-time helpers
/// <summary>
/// Write text directly into the generated output
/// </summary>
public void Write(string textToAppend)
{
if (string.IsNullOrEmpty(textToAppend))
{
return;
}
// If we're starting off, or if the previous text ended with a newline,
// we have to append the current indent first.
if (((this.GenerationEnvironment.Length == 0)
|| this.endsWithNewline))
{
this.GenerationEnvironment.Append(this.currentIndentField);
this.endsWithNewline = false;
}
// Check if the current text ends with a newline
if (textToAppend.EndsWith(global::System.Environment.NewLine, global::System.StringComparison.CurrentCulture))
{
this.endsWithNewline = true;
}
// This is an optimization. If the current indent is "", then we don't have to do any
// of the more complex stuff further down.
if ((this.currentIndentField.Length == 0))
{
this.GenerationEnvironment.Append(textToAppend);
return;
}
// Everywhere there is a newline in the text, add an indent after it
textToAppend = textToAppend.Replace(global::System.Environment.NewLine, (global::System.Environment.NewLine + this.currentIndentField));
// If the text ends with a newline, then we should strip off the indent added at the very end
// because the appropriate indent will be added when the next time Write() is called
if (this.endsWithNewline)
{
this.GenerationEnvironment.Append(textToAppend, 0, (textToAppend.Length - this.currentIndentField.Length));
}
else
{
this.GenerationEnvironment.Append(textToAppend);
}
}
/// <summary>
/// Write text directly into the generated output
/// </summary>
public void WriteLine(string textToAppend)
{
this.Write(textToAppend);
this.GenerationEnvironment.AppendLine();
this.endsWithNewline = true;
}
/// <summary>
/// Write formatted text directly into the generated output
/// </summary>
public void Write(string format, params object[] args)
{
this.Write(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args));
}
/// <summary>
/// Write formatted text directly into the generated output
/// </summary>
public void WriteLine(string format, params object[] args)
{
this.WriteLine(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args));
}
/// <summary>
/// Raise an error
/// </summary>
public void Error(string message)
{
System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError();
error.ErrorText = message;
this.Errors.Add(error);
}
/// <summary>
/// Raise a warning
/// </summary>
public void Warning(string message)
{
System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError();
error.ErrorText = message;
error.IsWarning = true;
this.Errors.Add(error);
}
/// <summary>
/// Increase the indent
/// </summary>
public void PushIndent(string indent)
{
if ((indent == null))
{
throw new global::System.ArgumentNullException("indent");
}
this.currentIndentField = (this.currentIndentField + indent);
this.indentLengths.Add(indent.Length);
}
/// <summary>
/// Remove the last indent that was added with PushIndent
/// </summary>
public string PopIndent()
{
string returnValue = "";
if ((this.indentLengths.Count > 0))
{
int indentLength = this.indentLengths[(this.indentLengths.Count - 1)];
this.indentLengths.RemoveAt((this.indentLengths.Count - 1));
if ((indentLength > 0))
{
returnValue = this.currentIndentField.Substring((this.currentIndentField.Length - indentLength));
this.currentIndentField = this.currentIndentField.Remove((this.currentIndentField.Length - indentLength));
}
}
return returnValue;
}
/// <summary>
/// Remove any indentation
/// </summary>
public void ClearIndent()
{
this.indentLengths.Clear();
this.currentIndentField = "";
}
#endregion
#region ToString Helpers
/// <summary>
/// Utility class to produce culture-oriented representation of an object as a string.
/// </summary>
public class ToStringInstanceHelper
{
private System.IFormatProvider formatProviderField = global::System.Globalization.CultureInfo.InvariantCulture;
/// <summary>
/// Gets or sets format provider to be used by ToStringWithCulture method.
/// </summary>
public System.IFormatProvider FormatProvider
{
get
{
return this.formatProviderField ;
}
set
{
if ((value != null))
{
this.formatProviderField = value;
}
}
}
/// <summary>
/// This is called from the compile/run appdomain to convert objects within an expression block to a string
/// </summary>
public string ToStringWithCulture(object objectToConvert)
{
if ((objectToConvert == null))
{
throw new global::System.ArgumentNullException("objectToConvert");
}
System.Type t = objectToConvert.GetType();
System.Reflection.MethodInfo method = t.GetMethod("ToString", new System.Type[] {
typeof(System.IFormatProvider)});
if ((method == null))
{
return objectToConvert.ToString();
}
else
{
return ((string)(method.Invoke(objectToConvert, new object[] {
this.formatProviderField })));
}
}
}
private ToStringInstanceHelper toStringHelperField = new ToStringInstanceHelper();
/// <summary>
/// Helper to produce culture-oriented representation of an object as a string
/// </summary>
public ToStringInstanceHelper ToStringHelper
{
get
{
return this.toStringHelperField;
}
}
#endregion
}
#endregion
}

View File

@@ -0,0 +1,116 @@
<#@ template language="C#" linePragmas="false" visibility = "internal" #>
<#@ assembly name="System.Core" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Collections.Generic" #>
<#@ import namespace="Microsoft.CodeAnalysis" #>
//----------------------
// <auto-generated>
// This code was generated by a tool.
// </auto-generated>
//----------------------
using Azure.AI.OpenAI;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Tasks;
using System;
using AutoGen.Core;
using AutoGen.OpenAI.Extension;
<#if (!String.IsNullOrEmpty(NameSpace)) {#>
namespace <#=NameSpace#>
{
<#}#>
public partial class <#=ClassName#>
{
<#foreach (var functionContract in FunctionContracts) {#>
private class <#=functionContract.GetFunctionSchemaClassName()#>
{
<#foreach (var parameter in functionContract.Parameters) {#>
<#if (parameter.IsOptional) {#>
[JsonPropertyName(@"<#=parameter.Name#>")]
public <#=parameter.Type#> <#=parameter.Name#> {get; set;} = <#=parameter.DefaultValue#>;
<#} else {#>
[JsonPropertyName(@"<#=parameter.Name#>")]
public <#=parameter.Type#> <#=parameter.Name#> {get; set;}
<#}#>
<#}#>
}
public <#=functionContract.ReturnType#> <#=functionContract.GetFunctionWrapperName()#>(string arguments)
{
var schema = JsonSerializer.Deserialize<<#=functionContract.GetFunctionSchemaClassName()#>>(
arguments,
new JsonSerializerOptions
{
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
});
<# var argumentLists = string.Join(", ", functionContract.Parameters.Select(p => $"schema.{p.Name}")); #>
return <#=functionContract.Name#>(<#=argumentLists#>);
}
public FunctionContract <#=functionContract.GetFunctionContractName()#>
{
get => new FunctionContract
{
<#if (functionContract.Namespace != null) {#>
Namespace = @"<#=functionContract.Namespace#>",
<#}#>
<#if (functionContract.ClassName != null) {#>
ClassName = @"<#=functionContract.ClassName#>",
<#}#>
<#if (functionContract.Name != null) {#>
Name = @"<#=functionContract.Name#>",
<#}#>
<#if (functionContract.Description != null) {#>
Description = @"<#=functionContract.Description#>",
<#}#>
<#if (functionContract.ReturnType != null) {#>
ReturnType = typeof(<#=functionContract.ReturnType#>),
<#}#>
<#if (functionContract.ReturnDescription != null) {#>
ReturnDescription = @"<#=functionContract.ReturnDescription#>",
<#}#>
<#if (functionContract.Parameters != null) {#>
Parameters = new []
{
<#foreach (var parameter in functionContract.Parameters) {#>
new FunctionParameterContract
{
<#if (parameter.Name != null) {#>
Name = @"<#=parameter.Name#>",
<#}#>
<#if (parameter.Description != null) {#>
Description = @"<#=parameter.Description#>",
<#}#>
<#if (parameter.Type != null) {#>
ParameterType = typeof(<#=parameter.Type#>),
<#}#>
IsRequired = <#=parameter.IsOptional ? "false" : "true"#>,
<#if (parameter.DefaultValue != null) {#>
DefaultValue = <#=parameter.DefaultValue#>,
<#}#>
},
<#}#>
},
<#}#>
};
}
public Azure.AI.OpenAI.FunctionDefinition <#=functionContract.GetFunctionDefinitionName()#>
{
get => this.<#=functionContract.GetFunctionContractName()#>.ToOpenAIFunctionDefinition();
}
<#}#>
}
<#if (!String.IsNullOrEmpty(NameSpace)) {#>
}
<#}#>
<#+
public string NameSpace {get; set;}
public string ClassName {get; set;}
public IEnumerable<FunctionContract> FunctionContracts {get; set;}
public bool IsStatic {get; set;} = false;
#>

View File

@@ -0,0 +1,50 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// LLMConfigAPI.cs
using System;
using System.Collections.Generic;
using System.Linq;
using AutoGen.OpenAI;
namespace AutoGen
{
public static class LLMConfigAPI
{
public static IEnumerable<ILLMConfig> GetOpenAIConfigList(
string apiKey,
IEnumerable<string>? modelIDs = null)
{
var models = modelIDs ?? new[]
{
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-4",
"gpt-4-32k",
"gpt-4-0613",
"gpt-4-32k-0613",
"gpt-4-1106-preview",
};
return models.Select(modelId => new OpenAIConfig(apiKey, modelId));
}
public static IEnumerable<ILLMConfig> GetAzureOpenAIConfigList(
string endpoint,
string apiKey,
IEnumerable<string> deploymentNames)
{
return deploymentNames.Select(deploymentName => new AzureOpenAIConfig(endpoint, deploymentName, apiKey));
}
/// <summary>
/// Get a list of LLMConfig objects from a JSON file.
/// </summary>
internal static IEnumerable<ILLMConfig> ConfigListFromJson(
string filePath,
IEnumerable<string>? filterModels = null)
{
// Disable this API from documentation for now.
throw new NotImplementedException();
}
}
}

View File

@@ -0,0 +1,30 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// AssistantAgent.cs
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen;
public class AssistantAgent : ConversableAgent
{
public AssistantAgent(
string name,
string systemMessage = "You are a helpful AI assistant",
ConversableAgentConfig? llmConfig = null,
Func<IEnumerable<IMessage>, CancellationToken, Task<bool>>? isTermination = null,
HumanInputMode humanInputMode = HumanInputMode.NEVER,
IDictionary<string, Func<string, Task<string>>>? functionMap = null,
string? defaultReply = null)
: base(name: name,
systemMessage: systemMessage,
llmConfig: llmConfig,
isTermination: isTermination,
humanInputMode: humanInputMode,
functionMap: functionMap,
defaultReply: defaultReply)
{
}
}

View File

@@ -0,0 +1,156 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ConversableAgent.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using AutoGen.OpenAI;
namespace AutoGen;
public enum HumanInputMode
{
/// <summary>
/// NEVER prompt the user for input
/// </summary>
NEVER = 0,
/// <summary>
/// ALWAYS prompt the user for input
/// </summary>
ALWAYS = 1,
/// <summary>
/// prompt the user for input if the message is not a termination message
/// </summary>
AUTO = 2,
}
public class ConversableAgent : IAgent
{
private readonly IAgent? innerAgent;
private readonly string? defaultReply;
private readonly HumanInputMode humanInputMode;
private readonly IDictionary<string, Func<string, Task<string>>>? functionMap;
private readonly string systemMessage;
private readonly IEnumerable<FunctionContract>? functions;
public ConversableAgent(
string name,
string systemMessage = "You are a helpful AI assistant",
IAgent? innerAgent = null,
string? defaultAutoReply = null,
HumanInputMode humanInputMode = HumanInputMode.NEVER,
Func<IEnumerable<IMessage>, CancellationToken, Task<bool>>? isTermination = null,
IDictionary<string, Func<string, Task<string>>>? functionMap = null)
{
this.Name = name;
this.defaultReply = defaultAutoReply;
this.functionMap = functionMap;
this.humanInputMode = humanInputMode;
this.innerAgent = innerAgent;
this.IsTermination = isTermination;
this.systemMessage = systemMessage;
}
public ConversableAgent(
string name,
string systemMessage = "You are a helpful AI assistant",
ConversableAgentConfig? llmConfig = null,
Func<IEnumerable<IMessage>, CancellationToken, Task<bool>>? isTermination = null,
HumanInputMode humanInputMode = HumanInputMode.AUTO,
IDictionary<string, Func<string, Task<string>>>? functionMap = null,
string? defaultReply = null)
{
this.Name = name;
this.defaultReply = defaultReply;
this.functionMap = functionMap;
this.humanInputMode = humanInputMode;
this.IsTermination = isTermination;
this.systemMessage = systemMessage;
this.innerAgent = llmConfig?.ConfigList != null ? this.CreateInnerAgentFromConfigList(llmConfig) : null;
this.functions = llmConfig?.FunctionContracts;
}
private IAgent? CreateInnerAgentFromConfigList(ConversableAgentConfig config)
{
IAgent? agent = null;
foreach (var llmConfig in config.ConfigList ?? Enumerable.Empty<ILLMConfig>())
{
agent = agent switch
{
null => llmConfig switch
{
AzureOpenAIConfig azureConfig => new GPTAgent(this.Name!, this.systemMessage, azureConfig, temperature: config.Temperature ?? 0),
OpenAIConfig openAIConfig => new GPTAgent(this.Name!, this.systemMessage, openAIConfig, temperature: config.Temperature ?? 0),
_ => throw new ArgumentException($"Unsupported config type {llmConfig.GetType()}"),
},
IAgent innerAgent => innerAgent.RegisterReply(async (messages, cancellationToken) =>
{
return await innerAgent.GenerateReplyAsync(messages, cancellationToken: cancellationToken);
}),
};
}
return agent;
}
public string Name { get; }
public Func<IEnumerable<IMessage>, CancellationToken, Task<bool>>? IsTermination { get; }
public async Task<IMessage> GenerateReplyAsync(
IEnumerable<IMessage> messages,
GenerateReplyOptions? overrideOptions = null,
CancellationToken cancellationToken = default)
{
// if there's no system message, add system message to the first of chat history
if (!messages.Any(m => m.IsSystemMessage()))
{
var systemMessage = new TextMessage(Role.System, this.systemMessage, from: this.Name);
messages = new[] { systemMessage }.Concat(messages);
}
// process order: function_call -> human_input -> inner_agent -> default_reply -> self_execute
// first in, last out
// process default reply
MiddlewareAgent agent;
if (this.innerAgent != null)
{
agent = innerAgent.RegisterMiddleware(async (msgs, option, agent, ct) =>
{
var updatedMessages = msgs.Select(m =>
{
if (m.From == this.Name)
{
m.From = this.innerAgent.Name;
return m;
}
else
{
return m;
}
});
return await agent.GenerateReplyAsync(updatedMessages, option, ct);
});
}
else
{
agent = new MiddlewareAgent<DefaultReplyAgent>(new DefaultReplyAgent(this.Name!, this.defaultReply ?? "Default reply is not set. Please pass a default reply to assistant agent"));
}
// process human input
var humanInputMiddleware = new HumanInputMiddleware(mode: this.humanInputMode, isTermination: this.IsTermination);
agent.Use(humanInputMiddleware);
// process function call
var functionCallMiddleware = new FunctionCallMiddleware(functions: this.functions, functionMap: this.functionMap);
agent.Use(functionCallMiddleware);
return await agent.GenerateReplyAsync(messages, overrideOptions, cancellationToken);
}
}

View File

@@ -0,0 +1,30 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// UserProxyAgent.cs
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace AutoGen;
public class UserProxyAgent : ConversableAgent
{
public UserProxyAgent(
string name,
string systemMessage = "You are a helpful AI assistant",
ConversableAgentConfig? llmConfig = null,
Func<IEnumerable<IMessage>, CancellationToken, Task<bool>>? isTermination = null,
HumanInputMode humanInputMode = HumanInputMode.ALWAYS,
IDictionary<string, Func<string, Task<string>>>? functionMap = null,
string? defaultReply = null)
: base(name: name,
systemMessage: systemMessage,
llmConfig: llmConfig,
isTermination: isTermination,
humanInputMode: humanInputMode,
functionMap: functionMap,
defaultReply: defaultReply)
{
}
}

View File

@@ -0,0 +1,31 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<RootNamespace>AutoGen</RootNamespace>
</PropertyGroup>
<Import Project="$(RepoRoot)/dotnet/nuget/nuget-package.props" />
<PropertyGroup>
<!-- NuGet Package Settings -->
<Title>AutoGen</Title>
<Description>
The all-in-one package for AutoGen. This package provides contracts, core functionalities, OpenAI integration, source generator, etc. for AutoGen.
</Description>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="JsonSchema.Net.Generation" Version="$(JsonSchemaVersion)" />
<ProjectReference Include="..\AutoGen.LMStudio\AutoGen.LMStudio.csproj" />
<ProjectReference Include="..\AutoGen.Mistral\AutoGen.Mistral.csproj" />
<ProjectReference Include="..\AutoGen.SemanticKernel\AutoGen.SemanticKernel.csproj" />
<ProjectReference Include="..\AutoGen.SourceGenerator\AutoGen.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\AutoGen.Core\AutoGen.Core.csproj" />
<ProjectReference Include="..\AutoGen.OpenAI\AutoGen.OpenAI.csproj" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ConversableAgentConfig.cs
using System.Collections.Generic;
namespace AutoGen;
public class ConversableAgentConfig
{
public IEnumerable<FunctionContract>? FunctionContracts { get; set; }
public IEnumerable<ILLMConfig>? ConfigList { get; set; }
public float? Temperature { get; set; } = 0.7f;
public int? Timeout { get; set; }
}

View File

@@ -0,0 +1,4 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// GlobalUsing.cs
global using AutoGen.Core;

Some files were not shown because too many files have changed in this diff Show More