Add * before keyword args for ChatCompletionClient (#4822)

add * before keyword args

Co-authored-by: Leonardo Pinheiro <lpinheiro@microsoft.com>
Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
This commit is contained in:
Leonardo Pinheiro
2024-12-27 23:41:16 +10:00
committed by GitHub
parent edad1b6065
commit 9a2dbb4fba
4 changed files with 20 additions and 12 deletions

View File

@@ -29,6 +29,7 @@ class ChatCompletionClient(ABC, ComponentLoader):
async def create(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
# None means do not override the default
# A value means to override the client default - often specified in the constructor
@@ -41,6 +42,7 @@ class ChatCompletionClient(ABC, ComponentLoader):
def create_stream(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
# None means do not override the default
# A value means to override the client default - often specified in the constructor
@@ -56,10 +58,10 @@ class ChatCompletionClient(ABC, ComponentLoader):
def total_usage(self) -> RequestUsage: ...
@abstractmethod
def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: ...
def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: ...
@abstractmethod
def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: ...
def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: ...
@property
@abstractmethod

View File

@@ -92,6 +92,7 @@ async def test_caller_loop() -> None:
async def create(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
@@ -116,6 +117,7 @@ async def test_caller_loop() -> None:
def create_stream(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
@@ -129,10 +131,10 @@ async def test_caller_loop() -> None:
def total_usage(self) -> RequestUsage:
return RequestUsage(prompt_tokens=0, completion_tokens=0)
def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
return 0
def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int:
def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
return 0
@property