mirror of
https://github.com/Pythagora-io/gpt-pilot.git
synced 2026-01-09 21:27:53 -05:00
Experimental support for Anthropic Claude API
This commit is contained in:
@@ -66,6 +66,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
if '--api-key' in args:
|
if '--api-key' in args:
|
||||||
os.environ["OPENAI_API_KEY"] = args['--api-key']
|
os.environ["OPENAI_API_KEY"] = args['--api-key']
|
||||||
|
if '--model-name' in args:
|
||||||
|
os.environ['MODEL_NAME'] = args['--model-name']
|
||||||
if '--api-endpoint' in args:
|
if '--api-endpoint' in args:
|
||||||
os.environ["OPENAI_ENDPOINT"] = args['--api-endpoint']
|
os.environ["OPENAI_ENDPOINT"] = args['--api-endpoint']
|
||||||
|
|
||||||
|
|||||||
@@ -109,8 +109,9 @@ def create_gpt_chat_completion(messages: List[dict], req_type, project,
|
|||||||
{'function_calls': {'name': str, arguments: {...}}}
|
{'function_calls': {'name': str, arguments: {...}}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
model_name = os.getenv('MODEL_NAME', 'gpt-4')
|
||||||
gpt_data = {
|
gpt_data = {
|
||||||
'model': os.getenv('MODEL_NAME', 'gpt-4'),
|
'model': model_name,
|
||||||
'n': 1,
|
'n': 1,
|
||||||
'temperature': temperature,
|
'temperature': temperature,
|
||||||
'top_p': 1,
|
'top_p': 1,
|
||||||
@@ -133,8 +134,18 @@ def create_gpt_chat_completion(messages: List[dict], req_type, project,
|
|||||||
if prompt_data is not None and function_call_message is not None:
|
if prompt_data is not None and function_call_message is not None:
|
||||||
prompt_data['function_call_message'] = function_call_message
|
prompt_data['function_call_message'] = function_call_message
|
||||||
|
|
||||||
|
if '/' in model_name:
|
||||||
|
model_provider, model_name = model_name.split('/', 1)
|
||||||
|
else:
|
||||||
|
model_provider = 'openai'
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = stream_gpt_completion(gpt_data, req_type, project)
|
if model_provider == 'anthropic':
|
||||||
|
if not os.getenv('ANTHROPIC_API_KEY'):
|
||||||
|
os.environ['ANTHROPIC_API_KEY'] = os.getenv('OPENAI_API_KEY')
|
||||||
|
response = stream_anthropic(messages, function_call_message, gpt_data, model_name)
|
||||||
|
else:
|
||||||
|
response = stream_gpt_completion(gpt_data, req_type, project)
|
||||||
|
|
||||||
# Remove JSON schema and any added retry messages
|
# Remove JSON schema and any added retry messages
|
||||||
while len(messages) > messages_length:
|
while len(messages) > messages_length:
|
||||||
@@ -143,7 +154,7 @@ def create_gpt_chat_completion(messages: List[dict], req_type, project,
|
|||||||
except TokenLimitError as e:
|
except TokenLimitError as e:
|
||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f'The request to {os.getenv("ENDPOINT")} API failed: %s', e)
|
logger.error(f'The request to {os.getenv("ENDPOINT")} API for {model_provider}/{model_name} failed: %s', e, exc_info=True)
|
||||||
print(color_red(f'The request to {os.getenv("ENDPOINT")} API failed with error: {e}. Please try again later.'))
|
print(color_red(f'The request to {os.getenv("ENDPOINT")} API failed with error: {e}. Please try again later.'))
|
||||||
if isinstance(e, ApiError):
|
if isinstance(e, ApiError):
|
||||||
raise e
|
raise e
|
||||||
@@ -588,3 +599,48 @@ def postprocessing(gpt_response: str, req_type) -> str:
|
|||||||
|
|
||||||
def load_data_to_json(string):
|
def load_data_to_json(string):
|
||||||
return json.loads(fix_json(string))
|
return json.loads(fix_json(string))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def stream_anthropic(messages, function_call_message, gpt_data, model_name="claude-3-sonnet-20240229"):
    """
    Stream a chat completion from the Anthropic Claude API and return its text.

    A leading "system" message is passed as Claude's system prompt instead of
    being sent as a chat message, and consecutive messages with the same role
    are merged into one (joined by a blank line). The caller's ``messages``
    list and the dicts inside it are left untouched.

    :param messages: list of dicts with "role" and "content" keys
    :param function_call_message: when not None, the response is passed through
        `clean_json_response` and validated against ``gpt_data["functions"]``
    :param gpt_data: request data; only the "functions" key is read here
    :param model_name: Claude model identifier sent to the API
    :return: ``{"text": response}`` where response is the full streamed text
    :raises RuntimeError: if the 'anthropic' package is not installed
    """
    # Imported lazily so the dependency is only needed when Claude is used.
    try:
        import anthropic
    except ImportError as err:
        raise RuntimeError("The 'anthropic' package is required to use the Anthropic Claude LLM.") from err

    # No API key is passed explicitly; presumably the client picks it up from
    # the ANTHROPIC_API_KEY environment variable - confirm against the SDK docs.
    client = anthropic.Anthropic(
        base_url=os.getenv('ANTHROPIC_ENDPOINT'),
    )

    claude_system = "You are a software development AI assistant."
    claude_messages = messages
    # Guard against an empty conversation before peeking at the first message.
    if messages and messages[0]["role"] == "system":
        claude_system = messages[0]["content"]
        claude_messages = messages[1:]

    # Merge consecutive same-role messages. Each kept message is copied first,
    # so appending merged content never rewrites the caller's message dicts
    # (the caller reuses `messages` across retries).
    merged_messages = []
    for message in claude_messages:
        if merged_messages and merged_messages[-1]["role"] == message["role"]:
            merged_messages[-1]["content"] += "\n\n" + message["content"]
        else:
            merged_messages.append(dict(message))

    chunks = []
    # NOTE(review): temperature is fixed at 0.5 and any 'temperature' in
    # gpt_data is ignored - confirm this is intentional.
    with client.messages.stream(
        model=model_name,
        max_tokens=4096,
        temperature=0.5,
        system=claude_system,
        messages=merged_messages,
    ) as stream:
        for chunk in stream.text_stream:
            # Echo each chunk as it arrives; `type='stream'` relies on the
            # print function in scope accepting that keyword.
            print(chunk, type='stream', end='', flush=True)
            chunks.append(chunk)
    response = "".join(chunks)

    if function_call_message is not None:
        response = clean_json_response(response)
        assert_json_schema(response, gpt_data["functions"])

    return {"text": response}
|
||||||
|
|||||||
@@ -25,3 +25,4 @@ tiktoken==0.5.2
|
|||||||
urllib3==1.26.7
|
urllib3==1.26.7
|
||||||
wcwidth==0.2.8
|
wcwidth==0.2.8
|
||||||
yaspin==2.5.0
|
yaspin==2.5.0
|
||||||
|
anthropic==0.19.1
|
||||||
|
|||||||
Reference in New Issue
Block a user