From 0d7b1424386cfcc43bdeb782368ffb22b2044566 Mon Sep 17 00:00:00 2001
From: Senko Rasic
Date: Thu, 7 Mar 2024 20:23:41 +0100
Subject: [PATCH] Experimental support for Anthropic Claude API

---
 pilot/main.py                 |  2 ++
 pilot/utils/llm_connection.py | 62 +++++++++++++++++++++++++++++++++--
 requirements.txt              |  1 +
 3 files changed, 62 insertions(+), 3 deletions(-)

diff --git a/pilot/main.py b/pilot/main.py
index 01bb3f73..61462c06 100644
--- a/pilot/main.py
+++ b/pilot/main.py
@@ -66,6 +66,8 @@ if __name__ == "__main__":
 
         if '--api-key' in args:
             os.environ["OPENAI_API_KEY"] = args['--api-key']
+        if '--model-name' in args:
+            os.environ['MODEL_NAME'] = args['--model-name']
         if '--api-endpoint' in args:
             os.environ["OPENAI_ENDPOINT"] = args['--api-endpoint']
 
diff --git a/pilot/utils/llm_connection.py b/pilot/utils/llm_connection.py
index e0cfe21a..9aadbadb 100644
--- a/pilot/utils/llm_connection.py
+++ b/pilot/utils/llm_connection.py
@@ -109,8 +109,9 @@ def create_gpt_chat_completion(messages: List[dict], req_type, project,
         {'function_calls': {'name': str, arguments: {...}}}
     """
 
+    model_name = os.getenv('MODEL_NAME', 'gpt-4')
     gpt_data = {
-        'model': os.getenv('MODEL_NAME', 'gpt-4'),
+        'model': model_name,
         'n': 1,
         'temperature': temperature,
         'top_p': 1,
@@ -133,8 +134,18 @@ def create_gpt_chat_completion(messages: List[dict], req_type, project,
     if prompt_data is not None and function_call_message is not None:
         prompt_data['function_call_message'] = function_call_message
 
+    if '/' in model_name:
+        model_provider, model_name = model_name.split('/', 1)
+    else:
+        model_provider = 'openai'
+
     try:
-        response = stream_gpt_completion(gpt_data, req_type, project)
+        if model_provider == 'anthropic':
+            if not os.getenv('ANTHROPIC_API_KEY'):
+                os.environ['ANTHROPIC_API_KEY'] = os.getenv('OPENAI_API_KEY')
+            response = stream_anthropic(messages, function_call_message, gpt_data, model_name)
+        else:
+            response = stream_gpt_completion(gpt_data, req_type, project)
 
         # Remove JSON schema and any added retry messages
         while len(messages) > messages_length:
@@ -143,7 +154,7 @@ def create_gpt_chat_completion(messages: List[dict], req_type, project,
     except TokenLimitError as e:
         raise e
     except Exception as e:
-        logger.error(f'The request to {os.getenv("ENDPOINT")} API failed: %s', e)
+        logger.error(f'The request to {os.getenv("ENDPOINT")} API for {model_provider}/{model_name} failed: %s', e, exc_info=True)
         print(color_red(f'The request to {os.getenv("ENDPOINT")} API failed with error: {e}. Please try again later.'))
         if isinstance(e, ApiError):
             raise e
@@ -588,3 +599,48 @@ def postprocessing(gpt_response: str, req_type) -> str:
 
 def load_data_to_json(string):
     return json.loads(fix_json(string))
+
+
+
+def stream_anthropic(messages, function_call_message, gpt_data, model_name = "claude-3-sonnet-20240229"):
+    try:
+        import anthropic
+    except ImportError as err:
+        raise RuntimeError("The 'anthropic' package is required to use the Anthropic Claude LLM.") from err
+
+    client = anthropic.Anthropic(
+        base_url=os.getenv('ANTHROPIC_ENDPOINT'),
+    )
+
+    claude_system = "You are a software development AI assistant."
+    claude_messages = messages
+    if messages[0]["role"] == "system":
+        claude_system = messages[0]["content"]
+        claude_messages = messages[1:]
+
+    if len(claude_messages):
+        cm2 = [claude_messages[0]]
+        for i in range(1, len(claude_messages)):
+            if cm2[-1]["role"] == claude_messages[i]["role"]:
+                cm2[-1]["content"] += "\n\n" + claude_messages[i]["content"]
+            else:
+                cm2.append(claude_messages[i])
+        claude_messages = cm2
+
+    response = ""
+    with client.messages.stream(
+        model=model_name,
+        max_tokens=4096,
+        temperature=0.5,
+        system=claude_system,
+        messages=claude_messages,
+    ) as stream:
+        for chunk in stream.text_stream:
+            print(chunk, type='stream', end='', flush=True)
+            response += chunk
+
+    if function_call_message is not None:
+        response = clean_json_response(response)
+        assert_json_schema(response, gpt_data["functions"])
+
+    return {"text": response}
diff --git a/requirements.txt b/requirements.txt
index 141a5b62..04b46922 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -25,3 +25,4 @@ tiktoken==0.5.2
 urllib3==1.26.7
 wcwidth==0.2.8
 yaspin==2.5.0
+anthropic==0.19.1
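
Usage note with an illustrative sketch (not part of the patch itself): the provider is chosen by prefixing the model name, so passing --model-name anthropic/claude-3-sonnet-20240229 (or setting MODEL_NAME directly) routes create_gpt_chat_completion through stream_anthropic, with ANTHROPIC_API_KEY falling back to OPENAI_API_KEY when unset; an unprefixed name keeps the existing OpenAI path. The sketch below shows the message reshaping stream_anthropic applies before calling the Messages API, which takes the system prompt as a separate parameter and, at least as of this client version, expects alternating user/assistant turns. The helper name and the example messages are made up for illustration, and the dicts are copied before merging so the input list is left untouched.

# Illustrative sketch only (hypothetical helper, not code from the patch): how a
# chat history is reshaped for Claude's Messages API, the same transformation
# stream_anthropic performs. The leading "system" message becomes the separate
# system prompt, and consecutive same-role messages are merged into one turn.
def to_claude_payload(messages, default_system="You are a software development AI assistant."):
    system = default_system
    rest = messages
    if messages and messages[0]["role"] == "system":
        system = messages[0]["content"]
        rest = messages[1:]

    merged = []
    for m in rest:
        if merged and merged[-1]["role"] == m["role"]:
            # Copy before merging so the caller's message dicts are left untouched.
            merged[-1] = {"role": m["role"],
                          "content": merged[-1]["content"] + "\n\n" + m["content"]}
        else:
            merged.append(dict(m))
    return system, merged

system, claude_messages = to_claude_payload([
    {"role": "system", "content": "You are a software development AI assistant."},
    {"role": "user", "content": "Plan the next development task."},
    {"role": "user", "content": "Respond with JSON only."},
])
# system          -> "You are a software development AI assistant."
# claude_messages -> [{"role": "user",
#                      "content": "Plan the next development task.\n\nRespond with JSON only."}]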