This is a complete rewrite of the GPT Pilot core, from the ground up. It puts the agentic architecture front and center and fixes long-standing problems with the database architecture that couldn't be solved without breaking compatibility. Since the database structure and config file syntax have changed, projects and existing configs are imported automatically; see README.md for details. This also relicenses the project under the FSL-1.1-MIT license.
117 lines
3.9 KiB
Python
import datetime
import re
from typing import Optional

import tiktoken
from httpx import Timeout
from openai import AsyncOpenAI, RateLimitError

from core.config import LLMProvider
from core.llm.base import BaseLLMClient
from core.llm.convo import Convo
from core.log import get_logger

log = get_logger(__name__)

# Used only for fallback token-count estimates when the API response omits usage data
tokenizer = tiktoken.get_encoding("cl100k_base")


class OpenAIClient(BaseLLMClient):
    provider = LLMProvider.OPENAI

    def _init_client(self):
        self.client = AsyncOpenAI(
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            timeout=Timeout(
                max(self.config.connect_timeout, self.config.read_timeout),
                connect=self.config.connect_timeout,
                read=self.config.read_timeout,
            ),
        )

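    # A note on the Timeout above (illustrative, with hypothetical config values):
    # the positional argument is httpx's default for operations without an
    # explicit override (write and pool here), so connect_timeout=5 and
    # read_timeout=60 would yield Timeout(60, connect=5, read=60).
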
    async def _make_request(
        self,
        convo: Convo,
        temperature: Optional[float] = None,
        json_mode: bool = False,
    ) -> tuple[str, int, int]:
        """
        Stream a chat completion for the conversation and return the full
        response text together with the prompt and completion token counts.
        """
        completion_kwargs = {
            "model": self.config.model,
            "messages": convo.messages,
            "temperature": self.config.temperature if temperature is None else temperature,
            "stream": True,
            "stream_options": {
                "include_usage": True,
            },
        }
        if json_mode:
            completion_kwargs["response_format"] = {"type": "json_object"}

        stream = await self.client.chat.completions.create(**completion_kwargs)
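        # With stream_options.include_usage set, the final chunk of the stream
        # carries the usage statistics and has an empty choices list; the loop
        # below accounts for both kinds of chunk.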
        response = []
        prompt_tokens = 0
        completion_tokens = 0

        async for chunk in stream:
            if chunk.usage:
                prompt_tokens += chunk.usage.prompt_tokens
                completion_tokens += chunk.usage.completion_tokens

            if not chunk.choices:
                continue

            content = chunk.choices[0].delta.content
            if not content:
                continue

            response.append(content)
            if self.stream_handler:
                await self.stream_handler(content)

        response_str = "".join(response)

        # Tell the stream handler we're done
        if self.stream_handler:
            await self.stream_handler(None)

        if prompt_tokens == 0 and completion_tokens == 0:
            # Estimate token usage with tiktoken (~3 tokens of framing overhead
            # per message plus the encoded content), following the OpenAI cookbook:
            # https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken
            prompt_tokens = sum(3 + len(tokenizer.encode(msg["content"])) for msg in convo.messages)
            completion_tokens = len(tokenizer.encode(response_str))
            log.warning(
                "OpenAI response did not include token counts, estimating with tiktoken: "
                f"{prompt_tokens} input tokens, {completion_tokens} output tokens"
            )

        return response_str, prompt_tokens, completion_tokens

    def rate_limit_sleep(self, err: RateLimitError) -> Optional[datetime.timedelta]:
        """
        Return how long to sleep before retrying after a rate limit error.

        OpenAI rate limits docs:
        https://platform.openai.com/docs/guides/rate-limits/error-mitigation

        Limit reset times are in "2h32m54s" format.
        """
        headers = err.response.headers
        if "x-ratelimit-remaining-tokens" not in headers:
            return None

        # Header values are strings, so compare the remaining token count
        # numerically: if we're out of tokens, wait for the token window to
        # reset, otherwise wait for the request window to reset.
        remaining_tokens = headers["x-ratelimit-remaining-tokens"]
        time_regex = r"(?:(\d+)h)?(?:(\d+)m)?(?:(\d+)s)?"
        if int(remaining_tokens) == 0:
            match = re.search(time_regex, headers.get("x-ratelimit-reset-tokens", ""))
        else:
            match = re.search(time_regex, headers.get("x-ratelimit-reset-requests", ""))

        if match and any(match.groups()):
            # Each component is optional (eg. "54s" has no hours or minutes),
            # so missing groups default to zero.
            hours, minutes, seconds = (int(group or 0) for group in match.groups())
            total_seconds = hours * 3600 + minutes * 60 + seconds
        else:
            # Not sure how this would happen: we'd have to get a RateLimitError
            # with nothing (or an invalid entry) in the `reset` field. Use a
            # sane default.
            total_seconds = 5

        return datetime.timedelta(seconds=total_seconds)
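
# Illustration of the reset-time parsing above (hypothetical header value):
#
#   >>> match = re.search(r"(?:(\d+)h)?(?:(\d+)m)?(?:(\d+)s)?", "2h32m54s")
#   >>> [int(group or 0) for group in match.groups()]
#   [2, 32, 54]
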
__all__ = ["OpenAIClient"]
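
# Example usage (hypothetical wiring; the exact BaseLLMClient constructor and
# the Convo builder API are assumptions, not defined in this file):
#
#   from core.config import LLMConfig
#
#   async def ask(config: LLMConfig) -> str:
#       client = OpenAIClient(config)
#       convo = Convo().user("Say hello!")
#       response, prompt_tokens, completion_tokens = await client._make_request(convo)
#       return response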