diff --git a/euclid/utils/llm_connection.py b/euclid/utils/llm_connection.py index 5b08645e..da83de60 100644 --- a/euclid/utils/llm_connection.py +++ b/euclid/utils/llm_connection.py @@ -44,7 +44,7 @@ def get_tokens_in_messages(messages: List[str]) -> int: return sum(len(tokens) for tokens in tokenized_messages) -def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TOKENS_FOR_GPT_RESPONSE): +def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TOKENS_FOR_GPT_RESPONSE, function_calls=None): api_key = os.getenv("OPENAI_API_KEY") # tokens_in_messages = get_tokens_in_messages(messages) tokens_in_messages = 100 @@ -59,8 +59,16 @@ def create_gpt_chat_completion(messages: List[dict], req_type, min_tokens=MIN_TO 'stream': True } + if function_calls is not None: + gpt_data['functions'] = function_calls['definitions'] + gpt_data['function_call'] = "auto" + try: - return stream_gpt_completion(gpt_data, req_type) + response = stream_gpt_completion(gpt_data, req_type) + if 'function_calls' in response and function_calls is not None: + function_calls['callback'](response['function_calls']); + elif 'text' in response: + return response['text'] except Exception as e: print( 'The request to OpenAI API failed. Might be due to GPT being down or due to the too large message. It\'s ' @@ -87,9 +95,11 @@ def stream_gpt_completion(data, req_type): if response.status_code != 200: print(f'problem with request: {response.text}') logger.debug(f'problem with request: {response.text}') - return + return {} gpt_response = '' + function_calls = { 'name': '', 'arguments': '' } + for line in response.iter_lines(): # Ignore keep-alive new lines if line: @@ -104,18 +114,28 @@ def stream_gpt_completion(data, req_type): try: json_line = json.loads(line) + if json_line['choices'][0]['finish_reason'] == 'function_call': + function_calls['arguments'] = json.loads(function_calls['arguments']) + return { 'function_calls': function_calls }; + + json_line = json_line['choices'][0]['delta'] except json.JSONDecodeError: logger.error(f'Unable to decode line: {line}') continue # skip to the next line - if 'choices' in json_line: - content = json_line['choices'][0]['delta'].get('content') + if 'function_call' in json_line: + if 'name' in json_line['function_call']: + function_calls['name'] = json_line['function_call']['name'] + if 'arguments' in json_line['function_call']: + function_calls['arguments'] += json_line['function_call']['arguments'] + if 'content' in json_line: + content = json_line.get('content') if content: gpt_response += content logger.info(f'Response message: {gpt_response}') new_code = postprocessing(gpt_response, req_type) # TODO add type dynamically - return new_code + return { 'text': new_code } def postprocessing(gpt_response, req_type):