Initial implementation for function calling

This commit is contained in:
Zvonimir Sabljic
2023-07-28 10:13:54 +02:00
parent f9f02201cb
commit e9be9b0076

View File

@@ -44,7 +44,7 @@ def get_tokens_in_messages(messages: List[str]) -> int:
return sum(len(tokens) for tokens in tokenized_messages)
def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TOKENS_FOR_GPT_RESPONSE):
def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TOKENS_FOR_GPT_RESPONSE, function_calls=None):
api_key = os.getenv("OPENAI_API_KEY")
# tokens_in_messages = get_tokens_in_messages(messages)
tokens_in_messages = 100
@@ -59,8 +59,16 @@ def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TO
'stream': True
}
if function_calls is not None:
gpt_data['functions'] = function_calls['definitions']
gpt_data['function_call'] = "auto"
try:
return stream_gpt_completion(gpt_data, req_type)
response = stream_gpt_completion(gpt_data, req_type)
if 'function_calls' in response and function_calls is not None:
function_calls['callback'](response['function_calls']);
elif 'text' in response:
return response['text']
except Exception as e:
print(
'The request to OpenAI API failed. Might be due to GPT being down or due to the too large message. It\'s '
@@ -87,9 +95,11 @@ def stream_gpt_completion(data, req_type):
if response.status_code != 200:
print(f'problem with request: {response.text}')
logger.debug(f'problem with request: {response.text}')
return
return {}
gpt_response = ''
function_calls = { 'name': '', 'arguments': '' }
for line in response.iter_lines():
# Ignore keep-alive new lines
if line:
@@ -104,18 +114,28 @@ def stream_gpt_completion(data, req_type):
try:
json_line = json.loads(line)
if json_line['choices'][0]['finish_reason'] == 'function_call':
function_calls['arguments'] = json.loads(function_calls['arguments'])
return { 'function_calls': function_calls };
json_line = json_line['choices'][0]['delta']
except json.JSONDecodeError:
logger.error(f'Unable to decode line: {line}')
continue # skip to the next line
if 'choices' in json_line:
content = json_line['choices'][0]['delta'].get('content')
if 'function_call' in json_line:
if 'name' in json_line['function_call']:
function_calls['name'] = json_line['function_call']['name']
if 'arguments' in json_line['function_call']:
function_calls['arguments'] += json_line['function_call']['arguments']
if 'content' in json_line:
content = json_line.get('content')
if content:
gpt_response += content
logger.info(f'Response message: {gpt_response}')
new_code = postprocessing(gpt_response, req_type) # TODO add type dynamically
return new_code
return { 'text': new_code }
def postprocessing(gpt_response, req_type):