Merge branch 'main' of github.com:Pythagora-io/copilot
@@ -1,7 +1,10 @@
 import requests
+import os
+import sys
 import json
 import tiktoken
+import questionary

 from typing import List
 from jinja2 import Environment, FileSystemLoader

@@ -39,50 +42,54 @@ def get_prompt(prompt_name, data=None):


 def get_tokens_in_messages(messages: List[str]) -> int:
     tokenizer = tiktoken.get_encoding("cl100k_base")  # GPT-4 tokenizer
     tokenized_messages = [tokenizer.encode(message['content']) for message in messages]
     return sum(len(tokens) for tokens in tokenized_messages)


 def num_tokens_from_functions(functions, model="gpt-4"):
     """Return the number of tokens used by a list of functions."""
     encoding = tiktoken.get_encoding("cl100k_base")

     num_tokens = 0
     for function in functions:
         function_tokens = len(encoding.encode(function['name']))
         function_tokens += len(encoding.encode(function['description']))

         if 'parameters' in function:
             parameters = function['parameters']
             if 'properties' in parameters:
                 for propertiesKey in parameters['properties']:
                     function_tokens += len(encoding.encode(propertiesKey))
                     v = parameters['properties'][propertiesKey]
                     for field in v:
                         if field == 'type':
                             function_tokens += 2
                             function_tokens += len(encoding.encode(v['type']))
                         elif field == 'description':
                             function_tokens += 2
                             function_tokens += len(encoding.encode(v['description']))
                         elif field == 'enum':
                             function_tokens -= 3
                             for o in v['enum']:
                                 function_tokens += 3
                                 function_tokens += len(encoding.encode(o))
                         else:
                             print(f"Warning: not supported field {field}")
                 function_tokens += 11

         num_tokens += function_tokens

     num_tokens += 12
     return num_tokens


-def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TOKENS_FOR_GPT_RESPONSE, function_calls=None):
+def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TOKENS_FOR_GPT_RESPONSE,
+                               function_calls=None):
     tokens_in_messages = round(get_tokens_in_messages(messages) * 1.2)  # add 20% to account for not 100% accuracy
     if function_calls is not None:
-        tokens_in_messages += round(num_tokens_from_functions(function_calls['definitions']) * 1.2)  # add 20% to account for not 100% accuracy
+        tokens_in_messages += round(
+            num_tokens_from_functions(function_calls['definitions']) * 1.2)  # add 20% to account for not 100% accuracy
     if tokens_in_messages + min_tokens > MAX_GPT_MODEL_TOKENS:
         raise ValueError(f'Too many tokens in messages: {tokens_in_messages}. Please try a different test.')
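For context, a minimal self-contained sketch of the token-budget estimate these functions implement, assuming only that tiktoken is installed; the messages list below is a made-up example, not from the commit:

    import tiktoken

    # Same idea as get_tokens_in_messages above: encode each message's content
    # with the cl100k_base encoding, sum the token counts, then pad by 20%.
    enc = tiktoken.get_encoding("cl100k_base")

    messages = [
        {'role': 'system', 'content': 'You are a helpful coding assistant.'},
        {'role': 'user', 'content': 'Write a hello world script.'},
    ]

    prompt_tokens = sum(len(enc.encode(m['content'])) for m in messages)
    budget = round(prompt_tokens * 1.2)  # the same 20% safety margin used above
    print(prompt_tokens, budget)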
@@ -114,13 +121,34 @@ def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TO
         print(e)


+def delete_last_n_lines(n):
+    for _ in range(n):
+        # Move the cursor up one line
+        sys.stdout.write('\033[F')
+        # Clear the current line
+        sys.stdout.write('\033[K')
+
+
+def count_lines_based_on_width(content, width):
+    lines_required = sum(len(line) // width + 1 for line in content.split('\n'))
+    return lines_required
+
+
 def stream_gpt_completion(data, req_type):
-    def return_result(result_data):
-        # spinner_stop(spinner)
+    terminal_width = os.get_terminal_size().columns
+    lines_printed = 2
+    buffer = ""  # A buffer to accumulate incoming data
+
+    def return_result(result_data, lines_printed):
+        if buffer:
+            lines_printed += count_lines_based_on_width(buffer, terminal_width)
+        logger.info(f'lines printed: {lines_printed} - {terminal_width}')
+
+        delete_last_n_lines(lines_printed)
         return result_data

     # spinner = spinner_start(colored("Waiting for OpenAI API response...", 'yellow'))
-    colored("Waiting for OpenAI API response...", 'yellow')
+    print(colored("Waiting for OpenAI API response...", 'yellow'))
     api_key = os.getenv("OPENAI_API_KEY")

     logger.info(f'Request data: {data}')
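A self-contained sketch of the terminal trick delete_last_n_lines relies on: '\033[F' moves the cursor up one line and '\033[K' clears from the cursor to the end of the line. These are standard ANSI escapes, so this assumes an ANSI-capable terminal; the demo text is invented:

    import sys
    import time

    # Print a few lines, then erase them with the same escapes used above.
    for i in range(3):
        print(f"progress line {i + 1}")
    time.sleep(1)

    for _ in range(3):
        sys.stdout.write('\033[F')  # move the cursor up one line
        sys.stdout.write('\033[K')  # clear that line
    sys.stdout.flush()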
@@ -136,9 +164,21 @@ def stream_gpt_completion(data, req_type):
     logger.info(f'Response status code: {response.status_code}')

     if response.status_code != 200:
-        print(f'problem with request: {response.text}')
+        print(colored(f'There was a problem with request to openai API:', 'red'))
+        print(response.text)
+        user_message = questionary.text("Do you want to try make same request again? If yes, just press ENTER.",
+                                        style=questionary.Style([
+                                            ('question', 'fg:red'),
+                                            ('answer', 'fg:orange')
+                                        ])).ask()
+
+        lines_printed += count_lines_based_on_width(response.text, terminal_width) + 1
+        if user_message == '':
+            delete_last_n_lines(lines_printed)
+            return stream_gpt_completion(data, req_type)
+
         logger.debug(f'problem with request: {response.text}')
-        return return_result({})
+        return return_result({}, lines_printed)

     gpt_response = ''
     function_calls = {'name': '', 'arguments': ''}
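As a usage sketch, the retry prompt added here follows the standard questionary pattern: questionary.text(...).ask() returns the entered string, so pressing ENTER yields ''. The prompt text below is a made-up stand-in:

    import questionary

    answer = questionary.text(
        "Press ENTER to retry, or type anything else to skip.",
        style=questionary.Style([
            ('question', 'fg:red'),
            ('answer', 'fg:orange')
        ])
    ).ask()

    if answer == '':
        print("retrying...")  # the diff re-invokes stream_gpt_completion here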
@@ -163,7 +203,7 @@ def stream_gpt_completion(data, req_type):

             if json_line['choices'][0]['finish_reason'] == 'function_call':
                 function_calls['arguments'] = load_data_to_json(function_calls['arguments'])
-                return return_result({'function_calls': function_calls});
+                return return_result({'function_calls': function_calls}, lines_printed);

             json_line = json_line['choices'][0]['delta']

@@ -183,6 +223,13 @@ def stream_gpt_completion(data, req_type):
             if 'content' in json_line:
                 content = json_line.get('content')
                 if content:
+                    buffer += content  # accumulate the data
+
+                    # If you detect a natural breakpoint (e.g., line break or end of a response object), print & count:
+                    if buffer.endswith("\n"):  # or some other condition that denotes a breakpoint
+                        lines_printed += count_lines_based_on_width(buffer, terminal_width)
+                        buffer = ""  # reset the buffer
+
                     gpt_response += content
                     print(content, end='', flush=True)
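To illustrate the buffering scheme this hunk introduces, here is a toy, self-contained example with a fake stream of chunks (the chunk data and the 20-column width are invented): chunks accumulate in a buffer, and whenever a newline arrives the buffered text is converted into a count of physical terminal rows, counting wrapped lines:

    def count_lines_based_on_width(content, width):
        # copy of the helper added in this commit
        return sum(len(line) // width + 1 for line in content.split('\n'))

    chunks = ["Hello ", "world, this is a ", "streamed answer.\n", "Second line\n"]
    terminal_width = 20  # invented for the example
    lines_printed = 0
    buffer = ""

    for content in chunks:
        buffer += content
        print(content, end='', flush=True)
        if buffer.endswith("\n"):  # natural breakpoint, as in the diff
            lines_printed += count_lines_based_on_width(buffer, terminal_width)
            buffer = ""

    print(f"physical rows printed so far: {lines_printed}")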
@@ -190,10 +237,10 @@ def stream_gpt_completion(data, req_type):
     if function_calls['arguments'] != '':
         logger.info(f'Response via function call: {function_calls["arguments"]}')
         function_calls['arguments'] = load_data_to_json(function_calls['arguments'])
-        return return_result({'function_calls': function_calls})
+        return return_result({'function_calls': function_calls}, lines_printed)
     logger.info(f'Response message: {gpt_response}')
     new_code = postprocessing(gpt_response, req_type)  # TODO add type dynamically
-    return return_result({'text': new_code})
+    return return_result({'text': new_code}, lines_printed)


 def postprocessing(gpt_response, req_type):