Merge branch 'main' of github.com:Pythagora-io/copilot

This commit is contained in:
Zvonimir Sabljic
2023-08-14 18:23:04 +02:00

View File

@@ -1,7 +1,10 @@
import requests
import os
import sys
import json
import tiktoken
import questionary
from typing import List
from jinja2 import Environment, FileSystemLoader
@@ -39,50 +42,54 @@ def get_prompt(prompt_name, data=None):
def get_tokens_in_messages(messages: List[dict]) -> int:
    """Return the total token count of a list of chat messages.

    Each message is a dict with at least a ``'content'`` key (the previous
    annotation said ``List[str]``, but the body indexes ``message['content']``,
    so the elements are dicts — fixed here).

    :param messages: chat messages in OpenAI format ({'role': ..., 'content': ...}).
    :return: sum of token counts of every message's content.
    """
    tokenizer = tiktoken.get_encoding("cl100k_base")  # GPT-4 tokenizer
    tokenized_messages = [tokenizer.encode(message['content']) for message in messages]
    return sum(len(tokens) for tokens in tokenized_messages)
def num_tokens_from_functions(functions, model="gpt-4"):
    """Return the number of tokens used by a list of functions."""
    encoding = tiktoken.get_encoding("cl100k_base")

    total = 0
    for function in functions:
        # Every function contributes its name and its description.
        tokens = len(encoding.encode(function['name']))
        tokens += len(encoding.encode(function['description']))

        if 'parameters' in function:
            parameters = function['parameters']
            if 'properties' in parameters:
                for prop_name, prop in parameters['properties'].items():
                    tokens += len(encoding.encode(prop_name))
                    for field in prop:
                        if field == 'type':
                            tokens += 2 + len(encoding.encode(prop['type']))
                        elif field == 'description':
                            tokens += 2 + len(encoding.encode(prop['description']))
                        elif field == 'enum':
                            # enums get a small deduction up front, then a
                            # per-option surcharge plus the option's tokens.
                            tokens -= 3
                            for option in prop['enum']:
                                tokens += 3 + len(encoding.encode(option))
                        else:
                            print(f"Warning: not supported field {field}")
                # fixed per-properties-object overhead
                tokens += 11

        total += tokens

    # fixed overhead for the functions wrapper itself
    total += 12
    return total
def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TOKENS_FOR_GPT_RESPONSE, function_calls=None):
tokens_in_messages = round(get_tokens_in_messages(messages) * 1.2) # add 20% to account for not 100% accuracy
def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TOKENS_FOR_GPT_RESPONSE,
function_calls=None):
tokens_in_messages = round(get_tokens_in_messages(messages) * 1.2) # add 20% to account for not 100% accuracy
if function_calls is not None:
tokens_in_messages += round(num_tokens_from_functions(function_calls['definitions']) * 1.2) # add 20% to account for not 100% accuracy
tokens_in_messages += round(
num_tokens_from_functions(function_calls['definitions']) * 1.2) # add 20% to account for not 100% accuracy
if tokens_in_messages + min_tokens > MAX_GPT_MODEL_TOKENS:
raise ValueError(f'Too many tokens in messages: {tokens_in_messages}. Please try a different test.')
@@ -114,13 +121,34 @@ def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TO
print(e)
def delete_last_n_lines(n):
    """Erase the previous *n* lines from the terminal.

    Uses ANSI escape sequences: ``ESC[F`` moves the cursor up one line and
    ``ESC[K`` clears from the cursor to the end of that line, so repeating
    the pair *n* times wipes the last *n* printed lines.
    """
    sys.stdout.write('\033[F\033[K' * n)
def count_lines_based_on_width(content, width):
    """Return the number of terminal rows *content* occupies when
    hard-wrapped at *width* columns.

    Fixes an off-by-one at the wrap boundary: a line of exactly ``width``
    characters fits on one row, but the old ``len(line) // width + 1``
    counted it as two. An empty line still occupies one row.

    :param content: text that was (or will be) printed, '\n'-separated.
    :param width: terminal width in columns (must be > 0).
    :return: total number of terminal rows required.
    """
    # -(-len // width) is ceil division without importing math.
    return sum(max(1, -(-len(line) // width)) for line in content.split('\n'))
def stream_gpt_completion(data, req_type):
def return_result(result_data):
# spinner_stop(spinner)
terminal_width = os.get_terminal_size().columns
lines_printed = 2
buffer = "" # A buffer to accumulate incoming data
def return_result(result_data, lines_printed):
if buffer:
lines_printed += count_lines_based_on_width(buffer, terminal_width)
logger.info(f'lines printed: {lines_printed} - {terminal_width}')
delete_last_n_lines(lines_printed)
return result_data
# spinner = spinner_start(colored("Waiting for OpenAI API response...", 'yellow'))
colored("Waiting for OpenAI API response...", 'yellow')
print(colored("Waiting for OpenAI API response...", 'yellow'))
api_key = os.getenv("OPENAI_API_KEY")
logger.info(f'Request data: {data}')
@@ -136,9 +164,21 @@ def stream_gpt_completion(data, req_type):
logger.info(f'Response status code: {response.status_code}')
if response.status_code != 200:
print(f'problem with request: {response.text}')
print(colored(f'There was a problem with request to openai API:', 'red'))
print(response.text)
user_message = questionary.text("Do you want to try make same request again? If yes, just press ENTER.",
style=questionary.Style([
('question', 'fg:red'),
('answer', 'fg:orange')
])).ask()
lines_printed += count_lines_based_on_width(response.text, terminal_width) + 1
if user_message == '':
delete_last_n_lines(lines_printed)
return stream_gpt_completion(data, req_type)
logger.debug(f'problem with request: {response.text}')
return return_result({})
return return_result({}, lines_printed)
gpt_response = ''
function_calls = {'name': '', 'arguments': ''}
@@ -163,7 +203,7 @@ def stream_gpt_completion(data, req_type):
if json_line['choices'][0]['finish_reason'] == 'function_call':
function_calls['arguments'] = load_data_to_json(function_calls['arguments'])
return return_result({'function_calls': function_calls});
return return_result({'function_calls': function_calls}, lines_printed);
json_line = json_line['choices'][0]['delta']
@@ -183,6 +223,13 @@ def stream_gpt_completion(data, req_type):
if 'content' in json_line:
content = json_line.get('content')
if content:
buffer += content # accumulate the data
# If you detect a natural breakpoint (e.g., line break or end of a response object), print & count:
if buffer.endswith("\n"): # or some other condition that denotes a breakpoint
lines_printed += count_lines_based_on_width(buffer, terminal_width)
buffer = "" # reset the buffer
gpt_response += content
print(content, end='', flush=True)
@@ -190,10 +237,10 @@ def stream_gpt_completion(data, req_type):
if function_calls['arguments'] != '':
logger.info(f'Response via function call: {function_calls["arguments"]}')
function_calls['arguments'] = load_data_to_json(function_calls['arguments'])
return return_result({'function_calls': function_calls})
return return_result({'function_calls': function_calls}, lines_printed)
logger.info(f'Response message: {gpt_response}')
new_code = postprocessing(gpt_response, req_type) # TODO add type dynamically
return return_result({'text': new_code})
return return_result({'text': new_code}, lines_printed)
def postprocessing(gpt_response, req_type):