mirror of
https://github.com/Pythagora-io/gpt-pilot.git
synced 2026-01-09 13:17:55 -05:00
Initial implementation for function calling
This commit is contained in:
@@ -44,7 +44,7 @@ def get_tokens_in_messages(messages: List[str]) -> int:
|
||||
return sum(len(tokens) for tokens in tokenized_messages)
|
||||
|
||||
|
||||
def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TOKENS_FOR_GPT_RESPONSE):
|
||||
def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TOKENS_FOR_GPT_RESPONSE, function_calls=None):
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
# tokens_in_messages = get_tokens_in_messages(messages)
|
||||
tokens_in_messages = 100
|
||||
@@ -59,8 +59,16 @@ def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TO
|
||||
'stream': True
|
||||
}
|
||||
|
||||
if function_calls is not None:
|
||||
gpt_data['functions'] = function_calls['definitions']
|
||||
gpt_data['function_call'] = "auto"
|
||||
|
||||
try:
|
||||
return stream_gpt_completion(gpt_data, req_type)
|
||||
response = stream_gpt_completion(gpt_data, req_type)
|
||||
if 'function_calls' in response and function_calls is not None:
|
||||
function_calls['callback'](response['function_calls']);
|
||||
elif 'text' in response:
|
||||
return response['text']
|
||||
except Exception as e:
|
||||
print(
|
||||
'The request to OpenAI API failed. Might be due to GPT being down or due to the too large message. It\'s '
|
||||
@@ -87,9 +95,11 @@ def stream_gpt_completion(data, req_type):
|
||||
if response.status_code != 200:
|
||||
print(f'problem with request: {response.text}')
|
||||
logger.debug(f'problem with request: {response.text}')
|
||||
return
|
||||
return {}
|
||||
|
||||
gpt_response = ''
|
||||
function_calls = { 'name': '', 'arguments': '' }
|
||||
|
||||
for line in response.iter_lines():
|
||||
# Ignore keep-alive new lines
|
||||
if line:
|
||||
@@ -104,18 +114,28 @@ def stream_gpt_completion(data, req_type):
|
||||
|
||||
try:
|
||||
json_line = json.loads(line)
|
||||
if json_line['choices'][0]['finish_reason'] == 'function_call':
|
||||
function_calls['arguments'] = json.loads(function_calls['arguments'])
|
||||
return { 'function_calls': function_calls };
|
||||
|
||||
json_line = json_line['choices'][0]['delta']
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f'Unable to decode line: {line}')
|
||||
continue # skip to the next line
|
||||
|
||||
if 'choices' in json_line:
|
||||
content = json_line['choices'][0]['delta'].get('content')
|
||||
if 'function_call' in json_line:
|
||||
if 'name' in json_line['function_call']:
|
||||
function_calls['name'] = json_line['function_call']['name']
|
||||
if 'arguments' in json_line['function_call']:
|
||||
function_calls['arguments'] += json_line['function_call']['arguments']
|
||||
if 'content' in json_line:
|
||||
content = json_line.get('content')
|
||||
if content:
|
||||
gpt_response += content
|
||||
|
||||
logger.info(f'Response message: {gpt_response}')
|
||||
new_code = postprocessing(gpt_response, req_type) # TODO add type dynamically
|
||||
return new_code
|
||||
return { 'text': new_code }
|
||||
|
||||
|
||||
def postprocessing(gpt_response, req_type):
|
||||
|
||||
Reference in New Issue
Block a user